/ src / modules / cmdpal / Microsoft.CmdPal.UI.ViewModels / TopLevelCommandManager.cs
TopLevelCommandManager.cs
  1  // Copyright (c) Microsoft Corporation
  2  // The Microsoft Corporation licenses this file to you under the MIT license.
  3  // See the LICENSE file in the project root for more information.
  4  
  5  using System.Collections.Immutable;
  6  using System.Collections.ObjectModel;
  7  using System.Diagnostics;
  8  using CommunityToolkit.Mvvm.ComponentModel;
  9  using CommunityToolkit.Mvvm.Input;
 10  using CommunityToolkit.Mvvm.Messaging;
 11  using ManagedCommon;
 12  using Microsoft.CmdPal.Core.Common.Helpers;
 13  using Microsoft.CmdPal.Core.Common.Services;
 14  using Microsoft.CmdPal.Core.ViewModels;
 15  using Microsoft.CmdPal.UI.ViewModels.Messages;
 16  using Microsoft.CmdPal.UI.ViewModels.Services;
 17  using Microsoft.CommandPalette.Extensions;
 18  using Microsoft.CommandPalette.Extensions.Toolkit;
 19  using Microsoft.Extensions.DependencyInjection;
 20  
 21  namespace Microsoft.CmdPal.UI.ViewModels;
 22  
 23  public partial class TopLevelCommandManager : ObservableObject,
 24      IRecipient<ReloadCommandsMessage>,
 25      IPageContext,
 26      IDisposable
 27  {
 28      private readonly IServiceProvider _serviceProvider;
 29      private readonly ICommandProviderCache _commandProviderCache;
 30      private readonly TaskScheduler _taskScheduler;
 31  
 32      private readonly List<CommandProviderWrapper> _builtInCommands = [];
 33      private readonly List<CommandProviderWrapper> _extensionCommandProviders = [];
 34      private readonly Lock _commandProvidersLock = new();
 35      private readonly SupersedingAsyncGate _reloadCommandsGate;
 36  
 37      TaskScheduler IPageContext.Scheduler => _taskScheduler;
 38  
 39      public TopLevelCommandManager(IServiceProvider serviceProvider, ICommandProviderCache commandProviderCache)
 40      {
 41          _serviceProvider = serviceProvider;
 42          _commandProviderCache = commandProviderCache;
 43          _taskScheduler = _serviceProvider.GetService<TaskScheduler>()!;
 44          WeakReferenceMessenger.Default.Register<ReloadCommandsMessage>(this);
 45          _reloadCommandsGate = new(ReloadAllCommandsAsyncCore);
 46      }
 47  
 48      public ObservableCollection<TopLevelViewModel> TopLevelCommands { get; set; } = [];
 49  
 50      [ObservableProperty]
 51      public partial bool IsLoading { get; private set; } = true;
 52  
 53      public IEnumerable<CommandProviderWrapper> CommandProviders
 54      {
 55          get
 56          {
 57              lock (_commandProvidersLock)
 58              {
 59                  return _builtInCommands.Concat(_extensionCommandProviders).ToList();
 60              }
 61          }
 62      }
 63  
 64      public async Task<bool> LoadBuiltinsAsync()
 65      {
 66          var s = new Stopwatch();
 67          s.Start();
 68  
 69          lock (_commandProvidersLock)
 70          {
 71              _builtInCommands.Clear();
 72          }
 73  
 74          // Load built-In commands first. These are all in-proc, and
 75          // owned by our ServiceProvider.
 76          var builtInCommands = _serviceProvider.GetServices<ICommandProvider>();
 77          foreach (var provider in builtInCommands)
 78          {
 79              CommandProviderWrapper wrapper = new(provider, _taskScheduler);
 80              lock (_commandProvidersLock)
 81              {
 82                  _builtInCommands.Add(wrapper);
 83              }
 84  
 85              var commands = await LoadTopLevelCommandsFromProvider(wrapper);
 86              lock (TopLevelCommands)
 87              {
 88                  foreach (var c in commands)
 89                  {
 90                      TopLevelCommands.Add(c);
 91                  }
 92              }
 93          }
 94  
 95          s.Stop();
 96  
 97          Logger.LogDebug($"Loading built-ins took {s.ElapsedMilliseconds}ms");
 98  
 99          return true;
100      }
101  
102      // May be called from a background thread
103      private async Task<IEnumerable<TopLevelViewModel>> LoadTopLevelCommandsFromProvider(CommandProviderWrapper commandProvider)
104      {
105          WeakReference<IPageContext> weakSelf = new(this);
106  
107          await commandProvider.LoadTopLevelCommands(_serviceProvider, weakSelf);
108  
109          var commands = await Task.Factory.StartNew(
110              () =>
111              {
112                  List<TopLevelViewModel> commands = [];
113                  foreach (var item in commandProvider.TopLevelItems)
114                  {
115                      commands.Add(item);
116                  }
117  
118                  foreach (var item in commandProvider.FallbackItems)
119                  {
120                      if (item.IsEnabled)
121                      {
122                          commands.Add(item);
123                      }
124                  }
125  
126                  return commands;
127              },
128              CancellationToken.None,
129              TaskCreationOptions.None,
130              _taskScheduler);
131  
132          commandProvider.CommandsChanged -= CommandProvider_CommandsChanged;
133          commandProvider.CommandsChanged += CommandProvider_CommandsChanged;
134  
135          return commands;
136      }
137  
138      // By all accounts, we're already on a background thread (the COM call
139      // to handle the event shouldn't be on the main thread.). But just to
140      // be sure we don't block the caller, hop off this thread
141      private void CommandProvider_CommandsChanged(CommandProviderWrapper sender, IItemsChangedEventArgs args) =>
142          _ = Task.Run(async () => await UpdateCommandsForProvider(sender, args));
143  
144      /// <summary>
145      /// Called when a command provider raises its ItemsChanged event. We'll
146      /// remove the old commands from the top-level list and try to put the new
147      /// ones in the same place in the list.
148      /// </summary>
149      /// <param name="sender">The provider who's commands changed</param>
150      /// <param name="args">the ItemsChangedEvent the provider raised</param>
151      /// <returns>an awaitable task</returns>
152      private async Task UpdateCommandsForProvider(CommandProviderWrapper sender, IItemsChangedEventArgs args)
153      {
154          WeakReference<IPageContext> weakSelf = new(this);
155          await sender.LoadTopLevelCommands(_serviceProvider, weakSelf);
156  
157          List<TopLevelViewModel> newItems = [.. sender.TopLevelItems];
158          foreach (var i in sender.FallbackItems)
159          {
160              if (i.IsEnabled)
161              {
162                  newItems.Add(i);
163              }
164          }
165  
166          // modify the TopLevelCommands under shared lock; event if we clone it, we don't want
167          // TopLevelCommands to get modified while we're working on it. Otherwise, we might
168          // out clone would be stale at the end of this method.
169          lock (TopLevelCommands)
170          {
171              // Work on a clone of the list, so that we can just do one atomic
172              // update to the actual observable list at the end
173              // TODO: just added a lock around all of this anyway, but keeping the clone
174              // while looking on some other ways to improve this; can be removed later.
175              List<TopLevelViewModel> clone = [.. TopLevelCommands];
176  
177              var startIndex = FindIndexForFirstProviderItem(clone, sender.ProviderId);
178              clone.RemoveAll(item => item.CommandProviderId == sender.ProviderId);
179              clone.InsertRange(startIndex, newItems);
180  
181              ListHelpers.InPlaceUpdateList(TopLevelCommands, clone);
182          }
183  
184          return;
185  
186          static int FindIndexForFirstProviderItem(List<TopLevelViewModel> topLevelItems, string providerId)
187          {
188              // Tricky: all Commands from a single provider get added to the
189              // top-level list all together, in a row. So if we find just the first
190              // one, we can slice it out and insert the new ones there.
191              for (var i = 0; i < topLevelItems.Count; i++)
192              {
193                  var wrapper = topLevelItems[i];
194                  try
195                  {
196                      if (providerId == wrapper.CommandProviderId)
197                      {
198                          return i;
199                      }
200                  }
201                  catch
202                  {
203                  }
204              }
205  
206              // If we didn't find any, then we just append the new commands to the end of the list.
207              return topLevelItems.Count;
208          }
209      }
210  
211      public async Task ReloadAllCommandsAsync()
212      {
213          // gate ensures that the reload is serialized and if multiple calls
214          // request a reload, only the first and the last one will be executed.
215          // this should be superseded with a cancellable version.
216          await _reloadCommandsGate.ExecuteAsync(CancellationToken.None);
217      }
218  
219      private async Task ReloadAllCommandsAsyncCore(CancellationToken cancellationToken)
220      {
221          IsLoading = true;
222          var extensionService = _serviceProvider.GetService<IExtensionService>()!;
223          await extensionService.SignalStopExtensionsAsync();
224  
225          lock (TopLevelCommands)
226          {
227              TopLevelCommands.Clear();
228          }
229  
230          await LoadBuiltinsAsync();
231          _ = Task.Run(LoadExtensionsAsync);
232      }
233  
234      // Load commands from our extensions. Called on a background thread.
235      // Currently, this
236      // * queries the package catalog,
237      // * starts all the extensions,
238      // * then fetches the top-level commands from them.
239      // TODO In the future, we'll probably abstract some of this away, to have
240      // separate extension tracking vs stub loading.
241      [RelayCommand]
242      public async Task<bool> LoadExtensionsAsync()
243      {
244          var extensionService = _serviceProvider.GetService<IExtensionService>()!;
245  
246          extensionService.OnExtensionAdded -= ExtensionService_OnExtensionAdded;
247          extensionService.OnExtensionRemoved -= ExtensionService_OnExtensionRemoved;
248  
249          var extensions = (await extensionService.GetInstalledExtensionsAsync()).ToImmutableList();
250          lock (_commandProvidersLock)
251          {
252              _extensionCommandProviders.Clear();
253          }
254  
255          if (extensions is not null)
256          {
257              await StartExtensionsAndGetCommands(extensions);
258          }
259  
260          extensionService.OnExtensionAdded += ExtensionService_OnExtensionAdded;
261          extensionService.OnExtensionRemoved += ExtensionService_OnExtensionRemoved;
262  
263          IsLoading = false;
264  
265          // Send on the current thread; receivers should marshal to UI if needed
266          WeakReferenceMessenger.Default.Send<ReloadFinishedMessage>();
267  
268          return true;
269      }
270  
271      private void ExtensionService_OnExtensionAdded(IExtensionService sender, IEnumerable<IExtensionWrapper> extensions)
272      {
273          // When we get an extension install event, hop off to a BG thread
274          _ = Task.Run(async () =>
275          {
276              // for each newly installed extension, start it and get commands
277              // from it. One single package might have more than one
278              // IExtensionWrapper in it.
279              await StartExtensionsAndGetCommands(extensions);
280          });
281      }
282  
283      private async Task StartExtensionsAndGetCommands(IEnumerable<IExtensionWrapper> extensions)
284      {
285          var timer = new Stopwatch();
286          timer.Start();
287  
288          // Start all extensions in parallel
289          var startTasks = extensions.Select(StartExtensionWithTimeoutAsync);
290  
291          // Wait for all extensions to start
292          var wrappers = (await Task.WhenAll(startTasks)).Where(wrapper => wrapper is not null).Select(w => w!).ToList();
293  
294          lock (_commandProvidersLock)
295          {
296              _extensionCommandProviders.AddRange(wrappers);
297          }
298  
299          // Load the commands from the providers in parallel
300          var loadTasks = wrappers.Select(LoadCommandsWithTimeoutAsync);
301  
302          var commandSets = (await Task.WhenAll(loadTasks)).Where(results => results is not null).Select(r => r!).ToList();
303  
304          lock (TopLevelCommands)
305          {
306              foreach (var commands in commandSets)
307              {
308                  foreach (var c in commands)
309                  {
310                      TopLevelCommands.Add(c);
311                  }
312              }
313          }
314  
315          timer.Stop();
316          Logger.LogDebug($"Loading extensions took {timer.ElapsedMilliseconds} ms");
317      }
318  
319      private async Task<CommandProviderWrapper?> StartExtensionWithTimeoutAsync(IExtensionWrapper extension)
320      {
321          Logger.LogDebug($"Starting {extension.PackageFullName}");
322          try
323          {
324              await extension.StartExtensionAsync().WaitAsync(TimeSpan.FromSeconds(10));
325              return new CommandProviderWrapper(extension, _taskScheduler, _commandProviderCache);
326          }
327          catch (Exception ex)
328          {
329              Logger.LogError($"Failed to start extension {extension.PackageFullName}: {ex}");
330              return null; // Return null for failed extensions
331          }
332      }
333  
334      private async Task<IEnumerable<TopLevelViewModel>?> LoadCommandsWithTimeoutAsync(CommandProviderWrapper wrapper)
335      {
336          try
337          {
338              return await LoadTopLevelCommandsFromProvider(wrapper!).WaitAsync(TimeSpan.FromSeconds(10));
339          }
340          catch (TimeoutException)
341          {
342              Logger.LogError($"Loading commands from {wrapper!.ExtensionHost?.Extension?.PackageFullName} timed out");
343          }
344          catch (Exception ex)
345          {
346              Logger.LogError($"Failed to load commands for extension {wrapper!.ExtensionHost?.Extension?.PackageFullName}: {ex}");
347          }
348  
349          return null;
350      }
351  
352      private void ExtensionService_OnExtensionRemoved(IExtensionService sender, IEnumerable<IExtensionWrapper> extensions)
353      {
354          // When we get an extension uninstall event, hop off to a BG thread
355          _ = Task.Run(
356              async () =>
357              {
358                  // Then find all the top-level commands that belonged to that extension
359                  List<TopLevelViewModel> commandsToRemove = [];
360                  lock (TopLevelCommands)
361                  {
362                      foreach (var extension in extensions)
363                      {
364                          foreach (var command in TopLevelCommands)
365                          {
366                              var host = command.ExtensionHost;
367                              if (host?.Extension == extension)
368                              {
369                                  commandsToRemove.Add(command);
370                              }
371                          }
372                      }
373                  }
374  
375                  // Then back on the UI thread (remember, TopLevelCommands is
376                  // Observable, so you can't touch it on the BG thread)...
377                  await Task.Factory.StartNew(
378                  () =>
379                  {
380                      // ... remove all the deleted commands.
381                      lock (TopLevelCommands)
382                      {
383                          if (commandsToRemove.Count != 0)
384                          {
385                              foreach (var deleted in commandsToRemove)
386                              {
387                                  TopLevelCommands.Remove(deleted);
388                              }
389                          }
390                      }
391                  },
392                  CancellationToken.None,
393                  TaskCreationOptions.None,
394                  _taskScheduler);
395              });
396      }
397  
398      public TopLevelViewModel? LookupCommand(string id)
399      {
400          lock (TopLevelCommands)
401          {
402              foreach (var command in TopLevelCommands)
403              {
404                  if (command.Id == id)
405                  {
406                      return command;
407                  }
408              }
409          }
410  
411          return null;
412      }
413  
414      public void Receive(ReloadCommandsMessage message) =>
415          ReloadAllCommandsAsync().ConfigureAwait(false);
416  
417      void IPageContext.ShowException(Exception ex, string? extensionHint)
418      {
419          var message = DiagnosticsHelper.BuildExceptionMessage(ex, extensionHint ?? "TopLevelCommandManager");
420          CommandPaletteHost.Instance.Log(message);
421      }
422  
423      internal bool IsProviderActive(string id)
424      {
425          lock (_commandProvidersLock)
426          {
427              return _builtInCommands.Any(wrapper => wrapper.Id == id && wrapper.IsActive)
428                     || _extensionCommandProviders.Any(wrapper => wrapper.Id == id && wrapper.IsActive);
429          }
430      }
431  
432      public void Dispose()
433      {
434          _reloadCommandsGate.Dispose();
435          GC.SuppressFinalize(this);
436      }
437  }