/ src / modules / cmdpal / Microsoft.CmdPal.UI.ViewModels / CommandProviderWrapper.cs
CommandProviderWrapper.cs
  1  // Copyright (c) Microsoft Corporation
  2  // The Microsoft Corporation licenses this file to you under the MIT license.
  3  // See the LICENSE file in the project root for more information.
  4  
  5  using ManagedCommon;
  6  using Microsoft.CmdPal.Core.Common.Services;
  7  using Microsoft.CmdPal.Core.ViewModels;
  8  using Microsoft.CmdPal.Core.ViewModels.Models;
  9  using Microsoft.CmdPal.UI.ViewModels.Services;
 10  using Microsoft.CommandPalette.Extensions;
 11  using Microsoft.Extensions.DependencyInjection;
 12  
 13  using Windows.Foundation;
 14  
 15  namespace Microsoft.CmdPal.UI.ViewModels;
 16  
 17  public sealed class CommandProviderWrapper
 18  {
 19      public bool IsExtension => Extension is not null;
 20  
 21      private readonly bool isValid;
 22  
 23      private readonly ExtensionObject<ICommandProvider> _commandProvider;
 24  
 25      private readonly TaskScheduler _taskScheduler;
 26  
 27      private readonly ICommandProviderCache? _commandProviderCache;
 28  
 29      public TopLevelViewModel[] TopLevelItems { get; private set; } = [];
 30  
 31      public TopLevelViewModel[] FallbackItems { get; private set; } = [];
 32  
 33      public string DisplayName { get; private set; } = string.Empty;
 34  
 35      public IExtensionWrapper? Extension { get; }
 36  
 37      public CommandPaletteHost ExtensionHost { get; private set; }
 38  
 39      public event TypedEventHandler<CommandProviderWrapper, IItemsChangedEventArgs>? CommandsChanged;
 40  
 41      public string Id { get; private set; } = string.Empty;
 42  
 43      public IconInfoViewModel Icon { get; private set; } = new(null);
 44  
 45      public CommandSettingsViewModel? Settings { get; private set; }
 46  
 47      public bool IsActive { get; private set; }
 48  
 49      public string ProviderId => string.IsNullOrEmpty(Extension?.ExtensionUniqueId) ? Id : Extension.ExtensionUniqueId;
 50  
 51      public CommandProviderWrapper(ICommandProvider provider, TaskScheduler mainThread)
 52      {
 53          // This ctor is only used for in-proc builtin commands. So the Unsafe!
 54          // calls are pretty dang safe actually.
 55          _commandProvider = new(provider);
 56          _taskScheduler = mainThread;
 57  
 58          // Hook the extension back into us
 59          ExtensionHost = new CommandPaletteHost(provider);
 60          _commandProvider.Unsafe!.InitializeWithHost(ExtensionHost);
 61  
 62          _commandProvider.Unsafe!.ItemsChanged += CommandProvider_ItemsChanged;
 63  
 64          isValid = true;
 65          Id = provider.Id;
 66          DisplayName = provider.DisplayName;
 67          Icon = new(provider.Icon);
 68          Icon.InitializeProperties();
 69  
 70          // Note: explicitly not InitializeProperties()ing the settings here. If
 71          // we do that, then we'd regress GH #38321
 72          Settings = new(provider.Settings, this, _taskScheduler);
 73  
 74          Logger.LogDebug($"Initialized command provider {ProviderId}");
 75      }
 76  
 77      public CommandProviderWrapper(IExtensionWrapper extension, TaskScheduler mainThread, ICommandProviderCache commandProviderCache)
 78      {
 79          _taskScheduler = mainThread;
 80          _commandProviderCache = commandProviderCache;
 81  
 82          Extension = extension;
 83          ExtensionHost = new CommandPaletteHost(extension);
 84          if (!Extension.IsRunning())
 85          {
 86              throw new ArgumentException("You forgot to start the extension. This is a CmdPal error - we need to make sure to call StartExtensionAsync");
 87          }
 88  
 89          var extensionImpl = extension.GetExtensionObject();
 90          var providerObject = extensionImpl?.GetProvider(ProviderType.Commands);
 91          if (providerObject is not ICommandProvider provider)
 92          {
 93              throw new ArgumentException("extension didn't actually implement ICommandProvider");
 94          }
 95  
 96          _commandProvider = new(provider);
 97  
 98          try
 99          {
100              var model = _commandProvider.Unsafe!;
101  
102              // Hook the extension back into us
103              model.InitializeWithHost(ExtensionHost);
104              model.ItemsChanged += CommandProvider_ItemsChanged;
105  
106              isValid = true;
107  
108              Logger.LogDebug($"Initialized extension command provider {Extension.PackageFamilyName}:{Extension.ExtensionUniqueId}");
109          }
110          catch (Exception e)
111          {
112              Logger.LogError("Failed to initialize CommandProvider for extension.");
113              Logger.LogError($"Extension was {Extension!.PackageFamilyName}");
114              Logger.LogError(e.ToString());
115          }
116  
117          isValid = true;
118      }
119  
120      private ProviderSettings GetProviderSettings(SettingsModel settings)
121      {
122          return settings.GetProviderSettings(this);
123      }
124  
125      public async Task LoadTopLevelCommands(IServiceProvider serviceProvider, WeakReference<IPageContext> pageContext)
126      {
127          if (!isValid)
128          {
129              IsActive = false;
130              RecallFromCache();
131              return;
132          }
133  
134          var settings = serviceProvider.GetService<SettingsModel>()!;
135  
136          var providerSettings = GetProviderSettings(settings);
137          IsActive = providerSettings.IsEnabled;
138          if (!IsActive)
139          {
140              RecallFromCache();
141              return;
142          }
143  
144          var displayInfoInitialized = false;
145          try
146          {
147              var model = _commandProvider.Unsafe!;
148  
149              Task<ICommandItem[]> loadTopLevelCommandsTask = new(model.TopLevelCommands);
150              loadTopLevelCommandsTask.Start();
151              var commands = await loadTopLevelCommandsTask.ConfigureAwait(false);
152  
153              // On a BG thread here
154              var fallbacks = model.FallbackCommands();
155  
156              if (model is ICommandProvider2 two)
157              {
158                  UnsafePreCacheApiAdditions(two);
159              }
160  
161              Id = model.Id;
162              DisplayName = model.DisplayName;
163              Icon = new(model.Icon);
164              Icon.InitializeProperties();
165              displayInfoInitialized = true;
166  
167              // Update cached display name
168              if (_commandProviderCache is not null && Extension?.ExtensionUniqueId is not null)
169              {
170                  _commandProviderCache.Memorize(Extension.ExtensionUniqueId, new CommandProviderCacheItem(model.DisplayName));
171              }
172  
173              // Note: explicitly not InitializeProperties()ing the settings here. If
174              // we do that, then we'd regress GH #38321
175              Settings = new(model.Settings, this, _taskScheduler);
176  
177              // We do need to explicitly initialize commands though
178              InitializeCommands(commands, fallbacks, serviceProvider, pageContext);
179  
180              Logger.LogDebug($"Loaded commands from {DisplayName} ({ProviderId})");
181          }
182          catch (Exception e)
183          {
184              Logger.LogError("Failed to load commands from extension");
185              Logger.LogError($"Extension was {Extension!.PackageFamilyName}");
186              Logger.LogError(e.ToString());
187  
188              if (!displayInfoInitialized)
189              {
190                  RecallFromCache();
191              }
192          }
193      }
194  
195      private void RecallFromCache()
196      {
197          var cached = _commandProviderCache?.Recall(ProviderId);
198          if (cached is not null)
199          {
200              DisplayName = cached.DisplayName;
201          }
202  
203          if (string.IsNullOrWhiteSpace(DisplayName))
204          {
205              DisplayName = Extension?.PackageDisplayName ?? Extension?.PackageFamilyName ?? ProviderId;
206          }
207      }
208  
209      private void InitializeCommands(ICommandItem[] commands, IFallbackCommandItem[] fallbacks, IServiceProvider serviceProvider, WeakReference<IPageContext> pageContext)
210      {
211          var settings = serviceProvider.GetService<SettingsModel>()!;
212          var providerSettings = GetProviderSettings(settings);
213  
214          var makeAndAdd = (ICommandItem? i, bool fallback) =>
215          {
216              CommandItemViewModel commandItemViewModel = new(new(i), pageContext);
217              TopLevelViewModel topLevelViewModel = new(commandItemViewModel, fallback, ExtensionHost, ProviderId, settings, providerSettings, serviceProvider, i);
218              topLevelViewModel.InitializeProperties();
219  
220              return topLevelViewModel;
221          };
222  
223          if (commands is not null)
224          {
225              TopLevelItems = commands
226                  .Select(c => makeAndAdd(c, false))
227                  .ToArray();
228          }
229  
230          if (fallbacks is not null)
231          {
232              FallbackItems = fallbacks
233                  .Select(c => makeAndAdd(c, true))
234                  .ToArray();
235          }
236      }
237  
238      private void UnsafePreCacheApiAdditions(ICommandProvider2 provider)
239      {
240          var apiExtensions = provider.GetApiExtensionStubs();
241          Logger.LogDebug($"Provider supports {apiExtensions.Length} extensions");
242          foreach (var a in apiExtensions)
243          {
244              if (a is IExtendedAttributesProvider command2)
245              {
246                  Logger.LogDebug($"{ProviderId}: Found an IExtendedAttributesProvider");
247              }
248          }
249      }
250  
251      public override bool Equals(object? obj) => obj is CommandProviderWrapper wrapper && isValid == wrapper.isValid;
252  
253      public override int GetHashCode() => _commandProvider.GetHashCode();
254  
255      private void CommandProvider_ItemsChanged(object sender, IItemsChangedEventArgs args) =>
256  
257          // We don't want to handle this ourselves - we want the
258          // TopLevelCommandManager to know about this, so they can remove
259          // our old commands from their own list.
260          //
261          // In handling this, a call will be made to `LoadTopLevelCommands` to
262          // retrieve the new items.
263          this.CommandsChanged?.Invoke(this, args);
264  }