TopLevelCommandManager.cs
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.Diagnostics;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using CommunityToolkit.Mvvm.Messaging;
using ManagedCommon;
using Microsoft.CmdPal.Core.Common.Helpers;
using Microsoft.CmdPal.Core.Common.Services;
using Microsoft.CmdPal.Core.ViewModels;
using Microsoft.CmdPal.UI.ViewModels.Messages;
using Microsoft.CmdPal.UI.ViewModels.Services;
using Microsoft.CommandPalette.Extensions;
using Microsoft.CommandPalette.Extensions.Toolkit;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.CmdPal.UI.ViewModels;

/// <summary>
/// Owns the built-in and extension-backed command providers and the flattened
/// <see cref="TopLevelCommands"/> list they contribute. Reloads everything on
/// <see cref="ReloadCommandsMessage"/>, and incrementally updates the list when a
/// provider's commands change or when extensions are installed / uninstalled.
/// </summary>
public partial class TopLevelCommandManager : ObservableObject,
    IRecipient<ReloadCommandsMessage>,
    IPageContext,
    IDisposable
{
    // Upper bound for starting an extension and for loading its commands.
    private static readonly TimeSpan ExtensionTimeout = TimeSpan.FromSeconds(10);

    private readonly IServiceProvider _serviceProvider;
    private readonly ICommandProviderCache _commandProviderCache;
    private readonly TaskScheduler _taskScheduler;

    private readonly List<CommandProviderWrapper> _builtInCommands = [];
    private readonly List<CommandProviderWrapper> _extensionCommandProviders = [];
    private readonly Lock _commandProvidersLock = new();
    private readonly SupersedingAsyncGate _reloadCommandsGate;

    TaskScheduler IPageContext.Scheduler => _taskScheduler;

    public TopLevelCommandManager(IServiceProvider serviceProvider, ICommandProviderCache commandProviderCache)
    {
        _serviceProvider = serviceProvider;
        _commandProviderCache = commandProviderCache;
        _taskScheduler = _serviceProvider.GetService<TaskScheduler>()!;

        // Create the gate before registering for messages, so an early
        // ReloadCommandsMessage can never observe a null gate.
        _reloadCommandsGate = new(ReloadAllCommandsAsyncCore);
        WeakReferenceMessenger.Default.Register<ReloadCommandsMessage>(this);
    }

    /// <summary>
    /// Gets or sets the flattened list of top-level commands from every loaded provider.
    /// This class locks on the collection instance itself around every read and mutation.
    /// </summary>
    public ObservableCollection<TopLevelViewModel> TopLevelCommands { get; set; } = [];

    [ObservableProperty]
    public partial bool IsLoading { get; private set; } = true;

    /// <summary>
    /// Gets a snapshot of all known command providers: built-ins first, then extensions.
    /// </summary>
    public IEnumerable<CommandProviderWrapper> CommandProviders
    {
        get
        {
            lock (_commandProvidersLock)
            {
                return _builtInCommands.Concat(_extensionCommandProviders).ToList();
            }
        }
    }

    /// <summary>
    /// (Re)creates a wrapper for every in-proc <see cref="ICommandProvider"/> registered in the
    /// service provider and appends each provider's commands to <see cref="TopLevelCommands"/>.
    /// </summary>
    /// <returns>Always <see langword="true"/>.</returns>
    public async Task<bool> LoadBuiltinsAsync()
    {
        var stopwatch = Stopwatch.StartNew();

        // Drop the previous generation of wrappers and stop listening to them, so a
        // stale wrapper can no longer push its items back into TopLevelCommands.
        List<CommandProviderWrapper> previous;
        lock (_commandProvidersLock)
        {
            previous = [.. _builtInCommands];
            _builtInCommands.Clear();
        }

        DetachProviders(previous);

        // Load built-in commands first. These are all in-proc, and
        // owned by our ServiceProvider.
        var builtInCommands = _serviceProvider.GetServices<ICommandProvider>();
        foreach (var provider in builtInCommands)
        {
            CommandProviderWrapper wrapper = new(provider, _taskScheduler);
            lock (_commandProvidersLock)
            {
                _builtInCommands.Add(wrapper);
            }

            var commands = await LoadTopLevelCommandsFromProvider(wrapper);
            lock (TopLevelCommands)
            {
                foreach (var c in commands)
                {
                    TopLevelCommands.Add(c);
                }
            }
        }

        stopwatch.Stop();
        Logger.LogDebug($"Loading built-ins took {stopwatch.ElapsedMilliseconds}ms");

        return true;
    }

    /// <summary>
    /// Asks <paramref name="commandProvider"/> to load its top-level commands, snapshots its
    /// top-level items plus its enabled fallback items on <c>_taskScheduler</c>, and subscribes
    /// (exactly once) to its <c>CommandsChanged</c> event. May be called from a background thread.
    /// </summary>
    /// <param name="commandProvider">The provider to load.</param>
    /// <returns>The provider's top-level items followed by its enabled fallback items.</returns>
    private async Task<IEnumerable<TopLevelViewModel>> LoadTopLevelCommandsFromProvider(CommandProviderWrapper commandProvider)
    {
        WeakReference<IPageContext> weakSelf = new(this);

        await commandProvider.LoadTopLevelCommands(_serviceProvider, weakSelf);

        var commands = await Task.Factory.StartNew(
            () => CollectCommands(commandProvider),
            CancellationToken.None,
            TaskCreationOptions.None,
            _taskScheduler);

        // Unsubscribe first so that loading the same wrapper twice never double-subscribes.
        commandProvider.CommandsChanged -= CommandProvider_CommandsChanged;
        commandProvider.CommandsChanged += CommandProvider_CommandsChanged;

        return commands;
    }

    /// <summary>
    /// Returns <paramref name="provider"/>'s top-level items followed by its enabled fallback items.
    /// </summary>
    private static List<TopLevelViewModel> CollectCommands(CommandProviderWrapper provider)
    {
        List<TopLevelViewModel> commands = [.. provider.TopLevelItems];
        foreach (var item in provider.FallbackItems)
        {
            if (item.IsEnabled)
            {
                commands.Add(item);
            }
        }

        return commands;
    }

    /// <summary>
    /// Stops listening to <c>CommandsChanged</c> on each of <paramref name="providers"/>.
    /// </summary>
    private void DetachProviders(IEnumerable<CommandProviderWrapper> providers)
    {
        foreach (var provider in providers)
        {
            provider.CommandsChanged -= CommandProvider_CommandsChanged;
        }
    }

    /// <summary>
    /// Runs <paramref name="work"/> on the thread pool without awaiting it. Nothing else
    /// observes these tasks, so any exception is logged rather than silently lost.
    /// </summary>
    private static void RunInBackground(Func<Task> work, string description) =>
        _ = Task.Run(async () =>
        {
            try
            {
                await work();
            }
            catch (Exception ex)
            {
                Logger.LogError($"{description} failed: {ex}");
            }
        });

    // By all accounts, we're already on a background thread (the COM call
    // to handle the event shouldn't be on the main thread). But just to
    // be sure we don't block the caller, hop off this thread.
    private void CommandProvider_CommandsChanged(CommandProviderWrapper sender, IItemsChangedEventArgs args) =>
        RunInBackground(() => UpdateCommandsForProvider(sender, args), "Updating commands for a changed provider");

    /// <summary>
    /// Called when a command provider raises its ItemsChanged event. We'll
    /// remove the old commands from the top-level list and try to put the new
    /// ones in the same place in the list.
    /// </summary>
    /// <param name="sender">The provider whose commands changed</param>
    /// <param name="args">the ItemsChangedEvent the provider raised</param>
    /// <returns>an awaitable task</returns>
    private async Task UpdateCommandsForProvider(CommandProviderWrapper sender, IItemsChangedEventArgs args)
    {
        WeakReference<IPageContext> weakSelf = new(this);
        await sender.LoadTopLevelCommands(_serviceProvider, weakSelf);

        var newItems = CollectCommands(sender);

        // Modify TopLevelCommands under its lock. Even though we work on a clone, we don't
        // want TopLevelCommands to change while we're working on it; otherwise our clone
        // would be stale by the time we write it back.
        lock (TopLevelCommands)
        {
            // Work on a clone of the list, so that we can do one in-place
            // update of the actual observable list at the end.
            List<TopLevelViewModel> clone = [.. TopLevelCommands];

            var startIndex = FindIndexForFirstProviderItem(clone, sender.ProviderId);
            clone.RemoveAll(item => item.CommandProviderId == sender.ProviderId);

            // Clamp defensively: InsertRange throws if the index is past the end.
            clone.InsertRange(Math.Min(startIndex, clone.Count), newItems);

            ListHelpers.InPlaceUpdateList(TopLevelCommands, clone);
        }

        static int FindIndexForFirstProviderItem(List<TopLevelViewModel> topLevelItems, string providerId)
        {
            // Tricky: all Commands from a single provider get added to the
            // top-level list all together, in a row. So if we find just the first
            // one, we can slice it out and insert the new ones there.
            for (var i = 0; i < topLevelItems.Count; i++)
            {
                var wrapper = topLevelItems[i];
                try
                {
                    if (providerId == wrapper.CommandProviderId)
                    {
                        return i;
                    }
                }
                catch
                {
                    // Best effort: skip items whose provider id can't be read.
                }
            }

            // If we didn't find any, then we just append the new commands to the end of the list.
            return topLevelItems.Count;
        }
    }

    public async Task ReloadAllCommandsAsync()
    {
        // The gate ensures that reloads are serialized, and if multiple calls
        // request a reload, only the first and the last one will be executed.
        // This should be superseded with a cancellable version.
        await _reloadCommandsGate.ExecuteAsync(CancellationToken.None);
    }

    // cancellationToken is currently unused; see the note in ReloadAllCommandsAsync.
    private async Task ReloadAllCommandsAsyncCore(CancellationToken cancellationToken)
    {
        IsLoading = true;
        var extensionService = _serviceProvider.GetService<IExtensionService>()!;
        await extensionService.SignalStopExtensionsAsync();

        lock (TopLevelCommands)
        {
            TopLevelCommands.Clear();
        }

        await LoadBuiltinsAsync();

        // Extensions load in the background; LoadExtensionsAsync clears IsLoading when done.
        RunInBackground(() => LoadExtensionsAsync(), "Loading extensions");
    }

    // Load commands from our extensions. Called on a background thread.
    // Currently, this
    // * queries the package catalog,
    // * starts all the extensions,
    // * then fetches the top-level commands from them.
    // TODO In the future, we'll probably abstract some of this away, to have
    // separate extension tracking vs stub loading.
    [RelayCommand]
    public async Task<bool> LoadExtensionsAsync()
    {
        var extensionService = _serviceProvider.GetService<IExtensionService>()!;

        // Don't react to install/uninstall events while the provider list is being rebuilt.
        extensionService.OnExtensionAdded -= ExtensionService_OnExtensionAdded;
        extensionService.OnExtensionRemoved -= ExtensionService_OnExtensionRemoved;

        try
        {
            var installed = await extensionService.GetInstalledExtensionsAsync();

            // Drop the previous generation of extension wrappers and stop listening to them.
            List<CommandProviderWrapper> previous;
            lock (_commandProvidersLock)
            {
                previous = [.. _extensionCommandProviders];
                _extensionCommandProviders.Clear();
            }

            DetachProviders(previous);

            // Check for null before materializing; ToImmutableList would throw on null.
            if (installed is not null)
            {
                await StartExtensionsAndGetCommands(installed.ToImmutableList());
            }
        }
        finally
        {
            // Always re-subscribe and clear the loading flag, even when enumerating or
            // starting extensions failed, so the UI is never stuck in a loading state.
            extensionService.OnExtensionAdded += ExtensionService_OnExtensionAdded;
            extensionService.OnExtensionRemoved += ExtensionService_OnExtensionRemoved;

            IsLoading = false;

            // Send on the current thread; receivers should marshal to UI if needed
            WeakReferenceMessenger.Default.Send<ReloadFinishedMessage>();
        }

        return true;
    }

    // When we get an extension install event, hop off to a BG thread. For each
    // newly installed extension, start it and get commands from it. One single
    // package might have more than one IExtensionWrapper in it.
    private void ExtensionService_OnExtensionAdded(IExtensionService sender, IEnumerable<IExtensionWrapper> extensions) =>
        RunInBackground(() => StartExtensionsAndGetCommands(extensions), "Loading newly installed extensions");

    /// <summary>
    /// Starts <paramref name="extensions"/> in parallel, registers a provider wrapper for each
    /// one that started, then loads their commands in parallel and appends them to
    /// <see cref="TopLevelCommands"/>. Extensions that fail or time out are logged and skipped.
    /// </summary>
    private async Task StartExtensionsAndGetCommands(IEnumerable<IExtensionWrapper> extensions)
    {
        var timer = Stopwatch.StartNew();

        // Start all extensions in parallel; failed ones come back as null.
        var started = await Task.WhenAll(extensions.Select(StartExtensionWithTimeoutAsync));
        var wrappers = started.OfType<CommandProviderWrapper>().ToList();

        lock (_commandProvidersLock)
        {
            _extensionCommandProviders.AddRange(wrappers);
        }

        // Load the commands from the providers in parallel; failed loads come back as null.
        var loaded = await Task.WhenAll(wrappers.Select(LoadCommandsWithTimeoutAsync));
        var commandSets = loaded.OfType<IEnumerable<TopLevelViewModel>>().ToList();

        lock (TopLevelCommands)
        {
            foreach (var commands in commandSets)
            {
                foreach (var c in commands)
                {
                    TopLevelCommands.Add(c);
                }
            }
        }

        timer.Stop();
        Logger.LogDebug($"Loading extensions took {timer.ElapsedMilliseconds} ms");
    }

    /// <summary>
    /// Starts <paramref name="extension"/> and wraps it as a command provider.
    /// </summary>
    /// <returns>The wrapper, or <see langword="null"/> if starting failed or timed out.</returns>
    private async Task<CommandProviderWrapper?> StartExtensionWithTimeoutAsync(IExtensionWrapper extension)
    {
        Logger.LogDebug($"Starting {extension.PackageFullName}");
        try
        {
            await extension.StartExtensionAsync().WaitAsync(ExtensionTimeout);
            return new CommandProviderWrapper(extension, _taskScheduler, _commandProviderCache);
        }
        catch (Exception ex)
        {
            Logger.LogError($"Failed to start extension {extension.PackageFullName}: {ex}");
            return null; // Return null for failed extensions
        }
    }

    /// <summary>
    /// Loads the top-level commands of <paramref name="wrapper"/>.
    /// </summary>
    /// <returns>The commands, or <see langword="null"/> if loading failed or timed out.</returns>
    private async Task<IEnumerable<TopLevelViewModel>?> LoadCommandsWithTimeoutAsync(CommandProviderWrapper wrapper)
    {
        try
        {
            return await LoadTopLevelCommandsFromProvider(wrapper).WaitAsync(ExtensionTimeout);
        }
        catch (TimeoutException)
        {
            Logger.LogError($"Loading commands from {wrapper.ExtensionHost?.Extension?.PackageFullName} timed out");
        }
        catch (Exception ex)
        {
            Logger.LogError($"Failed to load commands for extension {wrapper.ExtensionHost?.Extension?.PackageFullName}: {ex}");
        }

        return null;
    }

    // When we get an extension uninstall event, hop off to a BG thread.
    private void ExtensionService_OnExtensionRemoved(IExtensionService sender, IEnumerable<IExtensionWrapper> extensions) =>
        RunInBackground(() => RemoveExtensionsAsync(extensions), "Removing uninstalled extensions");

    /// <summary>
    /// Forgets the provider wrappers backed by <paramref name="extensions"/> and removes
    /// their top-level commands from <see cref="TopLevelCommands"/>.
    /// </summary>
    private async Task RemoveExtensionsAsync(IEnumerable<IExtensionWrapper> extensions)
    {
        List<IExtensionWrapper> removed = [.. extensions];

        // Drop the providers that belonged to the removed extensions and stop listening
        // to them, so CommandProviders / IsProviderActive no longer report them.
        List<CommandProviderWrapper> removedProviders;
        lock (_commandProvidersLock)
        {
            removedProviders = _extensionCommandProviders
                .Where(provider => removed.Any(extension => provider.ExtensionHost?.Extension == extension))
                .ToList();
            _extensionCommandProviders.RemoveAll(removedProviders.Contains);
        }

        DetachProviders(removedProviders);

        // Then find all the top-level commands that belonged to those extensions
        List<TopLevelViewModel> commandsToRemove;
        lock (TopLevelCommands)
        {
            commandsToRemove = TopLevelCommands
                .Where(command => removed.Any(extension => command.ExtensionHost?.Extension == extension))
                .ToList();
        }

        if (commandsToRemove.Count == 0)
        {
            return;
        }

        // Then back on the UI thread (remember, TopLevelCommands is
        // Observable, so you can't touch it on the BG thread)...
        await Task.Factory.StartNew(
            () =>
            {
                // ... remove all the deleted commands.
                lock (TopLevelCommands)
                {
                    foreach (var deleted in commandsToRemove)
                    {
                        TopLevelCommands.Remove(deleted);
                    }
                }
            },
            CancellationToken.None,
            TaskCreationOptions.None,
            _taskScheduler);
    }

    /// <summary>
    /// Finds the top-level command whose <c>Id</c> equals <paramref name="id"/>.
    /// </summary>
    /// <returns>The first match, or <see langword="null"/> if none.</returns>
    public TopLevelViewModel? LookupCommand(string id)
    {
        lock (TopLevelCommands)
        {
            foreach (var command in TopLevelCommands)
            {
                if (command.Id == id)
                {
                    return command;
                }
            }
        }

        return null;
    }

    // Starts the reload synchronously on the sender's thread (as before); nothing awaits
    // the resulting task, so log a fault instead of letting it go unobserved.
    public void Receive(ReloadCommandsMessage message) =>
        _ = ReloadAllCommandsAsync().ContinueWith(
            static t => Logger.LogError($"Reloading commands failed: {t.Exception}"),
            CancellationToken.None,
            TaskContinuationOptions.OnlyOnFaulted,
            TaskScheduler.Default);

    void IPageContext.ShowException(Exception ex, string? extensionHint)
    {
        var message = DiagnosticsHelper.BuildExceptionMessage(ex, extensionHint ?? "TopLevelCommandManager");
        CommandPaletteHost.Instance.Log(message);
    }

    /// <summary>
    /// Returns whether any built-in or extension provider with <paramref name="id"/> is active.
    /// </summary>
    internal bool IsProviderActive(string id)
    {
        lock (_commandProvidersLock)
        {
            return _builtInCommands.Any(wrapper => wrapper.Id == id && wrapper.IsActive)
                || _extensionCommandProviders.Any(wrapper => wrapper.Id == id && wrapper.IsActive);
        }
    }

    public void Dispose()
    {
        // Stop receiving reload requests; they would otherwise run against the disposed gate.
        WeakReferenceMessenger.Default.Unregister<ReloadCommandsMessage>(this);
        _reloadCommandsGate.Dispose();
        GC.SuppressFinalize(this);
    }
}