CommandProviderWrapper.cs
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using ManagedCommon;
using Microsoft.CmdPal.Core.Common.Services;
using Microsoft.CmdPal.Core.ViewModels;
using Microsoft.CmdPal.Core.ViewModels.Models;
using Microsoft.CmdPal.UI.ViewModels.Services;
using Microsoft.CommandPalette.Extensions;
using Microsoft.Extensions.DependencyInjection;

using Windows.Foundation;

namespace Microsoft.CmdPal.UI.ViewModels;

/// <summary>
/// Wraps a single <see cref="ICommandProvider"/> - either an in-proc built-in
/// provider or one obtained from an <see cref="IExtensionWrapper"/> - and
/// exposes its display information, settings and top-level/fallback items as
/// view models.
/// </summary>
public sealed class CommandProviderWrapper
{
    /// <summary>
    /// Gets a value indicating whether this wrapper was created from an
    /// extension (as opposed to a built-in provider).
    /// </summary>
    public bool IsExtension => Extension is not null;

    // True only once the provider has been hooked up to our host and its
    // ItemsChanged event. When false, LoadTopLevelCommands never calls into
    // the provider.
    private readonly bool _isValid;

    private readonly ExtensionObject<ICommandProvider> _commandProvider;

    private readonly TaskScheduler _taskScheduler;

    // Only supplied by the extension constructor; null for built-in providers.
    private readonly ICommandProviderCache? _commandProviderCache;

    /// <summary>
    /// Gets the view models built from the provider's top-level commands.
    /// Empty until <see cref="LoadTopLevelCommands"/> has loaded them.
    /// </summary>
    public TopLevelViewModel[] TopLevelItems { get; private set; } = [];

    /// <summary>
    /// Gets the view models built from the provider's fallback commands.
    /// Empty until <see cref="LoadTopLevelCommands"/> has loaded them.
    /// </summary>
    public TopLevelViewModel[] FallbackItems { get; private set; } = [];

    /// <summary>
    /// Gets the provider's display name. May come from the provider itself,
    /// from the cache, or from the extension's package information.
    /// </summary>
    public string DisplayName { get; private set; } = string.Empty;

    /// <summary>
    /// Gets the extension this provider came from, or null for built-in
    /// providers.
    /// </summary>
    public IExtensionWrapper? Extension { get; }

    /// <summary>
    /// Gets the host object that is handed to the provider via
    /// <c>InitializeWithHost</c>.
    /// </summary>
    public CommandPaletteHost ExtensionHost { get; private set; }

    /// <summary>
    /// Raised when the wrapped provider raises its <c>ItemsChanged</c> event.
    /// </summary>
    public event TypedEventHandler<CommandProviderWrapper, IItemsChangedEventArgs>? CommandsChanged;

    /// <summary>
    /// Gets the id reported by the provider. Empty for extension providers
    /// until <see cref="LoadTopLevelCommands"/> has read it.
    /// </summary>
    public string Id { get; private set; } = string.Empty;

    /// <summary>
    /// Gets the view model for the provider's icon.
    /// </summary>
    public IconInfoViewModel Icon { get; private set; } = new(null);

    /// <summary>
    /// Gets the view model for the provider's settings, if it has been created.
    /// </summary>
    public CommandSettingsViewModel? Settings { get; private set; }

    /// <summary>
    /// Gets a value indicating whether the provider is enabled in the
    /// settings, as determined by the last call to
    /// <see cref="LoadTopLevelCommands"/>. Always false for a wrapper that
    /// failed to initialize.
    /// </summary>
    public bool IsActive { get; private set; }

    /// <summary>
    /// Gets the extension's unique id when one is available, otherwise
    /// <see cref="Id"/>.
    /// </summary>
    public string ProviderId => string.IsNullOrEmpty(Extension?.ExtensionUniqueId) ? Id : Extension.ExtensionUniqueId;

    /// <summary>
    /// Creates a wrapper around an in-proc built-in command provider.
    /// </summary>
    /// <param name="provider">The built-in provider to wrap.</param>
    /// <param name="mainThread">Scheduler passed on to the settings view model.</param>
    public CommandProviderWrapper(ICommandProvider provider, TaskScheduler mainThread)
    {
        // This ctor is only used for in-proc builtin commands. So the Unsafe!
        // calls are pretty dang safe actually.
        _commandProvider = new(provider);
        _taskScheduler = mainThread;

        // Hook the extension back into us
        ExtensionHost = new CommandPaletteHost(provider);
        _commandProvider.Unsafe!.InitializeWithHost(ExtensionHost);

        _commandProvider.Unsafe!.ItemsChanged += CommandProvider_ItemsChanged;

        _isValid = true;
        Id = provider.Id;
        DisplayName = provider.DisplayName;
        Icon = new(provider.Icon);
        Icon.InitializeProperties();

        // Note: explicitly not InitializeProperties()ing the settings here. If
        // we do that, then we'd regress GH #38321
        Settings = new(provider.Settings, this, _taskScheduler);

        Logger.LogDebug($"Initialized command provider {ProviderId}");
    }

    /// <summary>
    /// Creates a wrapper around the command provider exposed by an extension.
    /// </summary>
    /// <param name="extension">The extension to wrap. It must already be running.</param>
    /// <param name="mainThread">Scheduler passed on to the settings view model.</param>
    /// <param name="commandProviderCache">Cache used to remember/recall the provider's display name.</param>
    /// <exception cref="ArgumentException">
    /// The extension is not running, or it does not expose an
    /// <see cref="ICommandProvider"/>.
    /// </exception>
    public CommandProviderWrapper(IExtensionWrapper extension, TaskScheduler mainThread, ICommandProviderCache commandProviderCache)
    {
        _taskScheduler = mainThread;
        _commandProviderCache = commandProviderCache;

        Extension = extension;
        ExtensionHost = new CommandPaletteHost(extension);
        if (!Extension.IsRunning())
        {
            throw new ArgumentException("You forgot to start the extension. This is a CmdPal error - we need to make sure to call StartExtensionAsync");
        }

        var extensionImpl = extension.GetExtensionObject();
        var providerObject = extensionImpl?.GetProvider(ProviderType.Commands);
        if (providerObject is not ICommandProvider provider)
        {
            throw new ArgumentException("extension didn't actually implement ICommandProvider");
        }

        _commandProvider = new(provider);

        try
        {
            var model = _commandProvider.Unsafe!;

            // Hook the extension back into us
            model.InitializeWithHost(ExtensionHost);
            model.ItemsChanged += CommandProvider_ItemsChanged;

            // Only mark the wrapper valid once the hookup above succeeded. If
            // it threw, _isValid keeps its default of false so that
            // LoadTopLevelCommands does not call into the provider.
            _isValid = true;

            Logger.LogDebug($"Initialized extension command provider {Extension.PackageFamilyName}:{Extension.ExtensionUniqueId}");
        }
        catch (Exception e)
        {
            Logger.LogError("Failed to initialize CommandProvider for extension.");
            Logger.LogError($"Extension was {extension.PackageFamilyName}");
            Logger.LogError(e.ToString());
        }
    }

    /// <summary>
    /// Looks up this provider's settings entry in the given settings model.
    /// </summary>
    private ProviderSettings GetProviderSettings(SettingsModel settings) => settings.GetProviderSettings(this);

    /// <summary>
    /// Loads the provider's top-level and fallback commands, and refreshes
    /// <see cref="Id"/>, <see cref="DisplayName"/>, <see cref="Icon"/> and
    /// <see cref="Settings"/> from the provider.
    /// </summary>
    /// <remarks>
    /// If the wrapper failed to initialize, or the provider is disabled in
    /// settings, the provider is not called and the display name is recalled
    /// from the cache instead. Exceptions thrown while talking to the provider
    /// are logged and not rethrown.
    /// </remarks>
    /// <param name="serviceProvider">Used to resolve the <see cref="SettingsModel"/> and passed to the created view models.</param>
    /// <param name="pageContext">Page context handed to each created command item view model.</param>
    /// <returns>A task that completes when loading has finished.</returns>
    public async Task LoadTopLevelCommands(IServiceProvider serviceProvider, WeakReference<IPageContext> pageContext)
    {
        if (!_isValid)
        {
            IsActive = false;
            RecallFromCache();
            return;
        }

        var settings = serviceProvider.GetService<SettingsModel>()!;

        var providerSettings = GetProviderSettings(settings);
        IsActive = providerSettings.IsEnabled;
        if (!IsActive)
        {
            RecallFromCache();
            return;
        }

        var displayInfoInitialized = false;
        try
        {
            var model = _commandProvider.Unsafe!;

            // Call TopLevelCommands through an explicitly started task and
            // await it without capturing the current context.
            Task<ICommandItem[]> loadTopLevelCommandsTask = new(model.TopLevelCommands);
            loadTopLevelCommandsTask.Start();
            var commands = await loadTopLevelCommandsTask.ConfigureAwait(false);

            // On a BG thread here
            var fallbacks = model.FallbackCommands();

            if (model is ICommandProvider2 two)
            {
                UnsafePreCacheApiAdditions(two);
            }

            Id = model.Id;
            DisplayName = model.DisplayName;
            Icon = new(model.Icon);
            Icon.InitializeProperties();
            displayInfoInitialized = true;

            // Update cached display name
            if (_commandProviderCache is not null && Extension?.ExtensionUniqueId is not null)
            {
                _commandProviderCache.Memorize(Extension.ExtensionUniqueId, new CommandProviderCacheItem(model.DisplayName));
            }

            // Note: explicitly not InitializeProperties()ing the settings here. If
            // we do that, then we'd regress GH #38321
            Settings = new(model.Settings, this, _taskScheduler);

            // We do need to explicitly initialize commands though
            InitializeCommands(commands, fallbacks, serviceProvider, pageContext);

            Logger.LogDebug($"Loaded commands from {DisplayName} ({ProviderId})");
        }
        catch (Exception e)
        {
            Logger.LogError("Failed to load commands from extension");

            // Extension is null for built-in providers, so don't dereference
            // it unconditionally: a NullReferenceException here would escape
            // the handler and skip the cache fallback below.
            Logger.LogError($"Extension was {Extension?.PackageFamilyName ?? ProviderId}");
            Logger.LogError(e.ToString());

            if (!displayInfoInitialized)
            {
                RecallFromCache();
            }
        }
    }

    /// <summary>
    /// Restores <see cref="DisplayName"/> from the cache when an entry exists.
    /// If the name is still blank afterwards, falls back to the extension's
    /// package display name, then its package family name, then
    /// <see cref="ProviderId"/>.
    /// </summary>
    private void RecallFromCache()
    {
        var cached = _commandProviderCache?.Recall(ProviderId);
        if (cached is not null)
        {
            DisplayName = cached.DisplayName;
        }

        if (string.IsNullOrWhiteSpace(DisplayName))
        {
            DisplayName = Extension?.PackageDisplayName ?? Extension?.PackageFamilyName ?? ProviderId;
        }
    }

    /// <summary>
    /// Builds <see cref="TopLevelItems"/> and <see cref="FallbackItems"/> from
    /// the command items returned by the provider. A null array leaves the
    /// corresponding property unchanged.
    /// </summary>
    private void InitializeCommands(ICommandItem[] commands, IFallbackCommandItem[] fallbacks, IServiceProvider serviceProvider, WeakReference<IPageContext> pageContext)
    {
        var settings = serviceProvider.GetService<SettingsModel>()!;
        var providerSettings = GetProviderSettings(settings);

        if (commands is not null)
        {
            TopLevelItems = commands
                .Select(c => MakeTopLevelViewModel(c, false))
                .ToArray();
        }

        if (fallbacks is not null)
        {
            FallbackItems = fallbacks
                .Select(c => MakeTopLevelViewModel(c, true))
                .ToArray();
        }

        // Wraps one command item in an initialized top-level view model.
        TopLevelViewModel MakeTopLevelViewModel(ICommandItem? item, bool fallback)
        {
            CommandItemViewModel commandItemViewModel = new(new(item), pageContext);
            TopLevelViewModel topLevelViewModel = new(commandItemViewModel, fallback, ExtensionHost, ProviderId, settings, providerSettings, serviceProvider, item);
            topLevelViewModel.InitializeProperties();

            return topLevelViewModel;
        }
    }

    /// <summary>
    /// Enumerates the API extension stubs reported by the provider, logging
    /// how many there are and each one that is an
    /// <see cref="IExtendedAttributesProvider"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): judging by the name, the type test presumably exists to
    /// warm up interface lookups for the newer API surface - confirm.
    /// </remarks>
    private void UnsafePreCacheApiAdditions(ICommandProvider2 provider)
    {
        var apiExtensions = provider.GetApiExtensionStubs();
        Logger.LogDebug($"Provider supports {apiExtensions.Length} extensions");
        foreach (var a in apiExtensions)
        {
            if (a is IExtendedAttributesProvider)
            {
                Logger.LogDebug($"{ProviderId}: Found an IExtendedAttributesProvider");
            }
        }
    }

    /// <summary>
    /// Two wrappers are equal when they have the same validity and wrap equal
    /// provider handles. The provider handle is part of the comparison so that
    /// equal wrappers always produce the same <see cref="GetHashCode"/>.
    /// </summary>
    public override bool Equals(object? obj) =>
        obj is CommandProviderWrapper other &&
        _isValid == other._isValid &&
        _commandProvider.Equals(other._commandProvider);

    /// <summary>
    /// Returns the hash code of the wrapped provider handle.
    /// </summary>
    public override int GetHashCode() => _commandProvider.GetHashCode();

    private void CommandProvider_ItemsChanged(object sender, IItemsChangedEventArgs args) =>

        // We don't want to handle this ourselves - we want the
        // TopLevelCommandManager to know about this, so they can remove
        // our old commands from their own list.
        //
        // In handling this, a call will be made to `LoadTopLevelCommands` to
        // retrieve the new items.
        this.CommandsChanged?.Invoke(this, args);
}