ExtensionWrapper.cs
// Copyright (c) Microsoft Corporation
// The Microsoft Corporation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.InteropServices;
using ManagedCommon;
using Microsoft.CmdPal.Core.Common.Services;
using Microsoft.CommandPalette.Extensions;
using Windows.ApplicationModel;
using Windows.ApplicationModel.AppExtensions;
using Windows.Win32;
using Windows.Win32.System.Com;
using WinRT;

namespace Microsoft.CmdPal.UI.ViewModels.Models;

/// <summary>
/// Holds the package metadata of one <see cref="AppExtension"/> and manages the lifetime of
/// the <see cref="IExtension"/> COM object that is activated from the extension's class id.
/// </summary>
public class ExtensionWrapper : IExtensionWrapper
{
    // 0x800706BA: the COM server behind a previously obtained proxy is no longer reachable.
    private const int HResultRpcServerNotRunning = -2147023174;

    // 0x80070003 (E_PATH_NOT_FOUND): the server registered for the class id could not be located.
    private const int HResultPathNotFound = -2147024893;

    private readonly string _appUserModelId;
    private readonly string _extensionId;

    // Guards every read-modify-write of _extensionObject.
    private readonly Lock _lock = new();
    private readonly List<ProviderType> _providerTypes = [];

    // Maps the interface requested by callers to the ProviderType handed to IExtension.GetProvider.
    private readonly Dictionary<Type, ProviderType> _providerTypeMap = new()
    {
        [typeof(ICommandProvider)] = ProviderType.Commands,
    };

    private IExtension? _extensionObject;

    /// <summary>
    /// Initializes a new instance of the <see cref="ExtensionWrapper"/> class by copying the
    /// identifying metadata out of <paramref name="appExtension"/>.
    /// </summary>
    /// <param name="appExtension">The app extension whose package information is captured.</param>
    /// <param name="classId">The COM class id (parsed as a GUID when the extension is started).</param>
    /// <exception cref="ArgumentNullException"><paramref name="classId"/> is <see langword="null"/>.</exception>
    public ExtensionWrapper(AppExtension appExtension, string classId)
    {
        PackageDisplayName = appExtension.Package.DisplayName;
        ExtensionDisplayName = appExtension.DisplayName;
        PackageFullName = appExtension.Package.Id.FullName;
        PackageFamilyName = appExtension.Package.Id.FamilyName;
        ExtensionClassId = classId ?? throw new ArgumentNullException(nameof(classId));
        Publisher = appExtension.Package.PublisherDisplayName;
        InstalledDate = appExtension.Package.InstalledDate;
        Version = appExtension.Package.Id.Version;
        _appUserModelId = appExtension.AppInfo.AppUserModelId;
        _extensionId = appExtension.Id;
    }

    public string PackageDisplayName { get; }

    public string ExtensionDisplayName { get; }

    public string PackageFullName { get; }

    public string PackageFamilyName { get; }

    public string ExtensionClassId { get; }

    public string Publisher { get; }

    public DateTimeOffset InstalledDate { get; }

    public PackageVersion Version { get; }

    /// <summary>
    /// Gets the unique id for this extension. The unique id is a concatenation of:
    /// <list type="number">
    /// <item>The AppUserModelId (AUMID) of the extension's application. The AUMID is the concatenation of the package
    /// family name and the application id and uniquely identifies the application containing the extension within
    /// the package.</item>
    /// <item>The Extension Id. This is the unique identifier of the extension within the application.</item>
    /// </list>
    /// The two parts are separated by a <c>!</c>.
    /// </summary>
    public string ExtensionUniqueId => $"{_appUserModelId}!{_extensionId}";

    /// <summary>
    /// Checks whether the held extension object still answers calls.
    /// </summary>
    /// <returns>
    /// <see langword="false"/> when no extension object is held or when the probe call fails with
    /// <c>HResultRpcServerNotRunning</c>; otherwise <see langword="true"/>.
    /// </returns>
    /// <exception cref="COMException">The probe call failed with any other HRESULT.</exception>
    public bool IsRunning()
    {
        if (_extensionObject is null)
        {
            return false;
        }

        try
        {
            // The returned name is irrelevant; the call only serves as a liveness probe.
            _extensionObject.As<IInspectable>().GetRuntimeClassName();
        }
        catch (COMException e) when (e.ErrorCode == HResultRpcServerNotRunning)
        {
            return false;
        }

        return true;
    }

    /// <summary>
    /// Activates the extension through <c>CoCreateInstance</c> (<c>CLSCTX_LOCAL_SERVER</c>) on a
    /// thread-pool thread, unless it is already running.
    /// </summary>
    /// <remarks>
    /// Activation failures are logged and swallowed; afterwards <see cref="GetExtensionObject"/>
    /// simply yields <see langword="null"/>.
    /// </remarks>
    /// <returns>A task that completes once the activation attempt has finished.</returns>
    public async Task StartExtensionAsync()
    {
        await Task.Run(() =>
        {
            lock (_lock)
            {
                if (IsRunning())
                {
                    return;
                }

                Logger.LogDebug($"Starting {ExtensionDisplayName} ({ExtensionClassId})");

                unsafe
                {
                    var extensionPtr = (void*)nint.Zero;
                    try
                    {
                        // HRESULTs commonly seen here:
                        // -2147024809: E_INVALIDARG
                        // -2147467262: E_NOINTERFACE
                        // -2147024893: E_PATH_NOT_FOUND
                        var guid = typeof(IExtension).GUID;

                        var hr = PInvoke.CoCreateInstance(Guid.Parse(ExtensionClassId), null, CLSCTX.CLSCTX_LOCAL_SERVER, guid, out extensionPtr);

                        if (hr.Value == HResultPathNotFound)
                        {
                            Logger.LogError($"Failed to find {ExtensionDisplayName}: {hr}. It may have been uninstalled or deleted.");

                            // A missing extension is an expected situation, so report it and
                            // leave without throwing.
                            return;
                        }

                        if (hr.Value != 0)
                        {
                            // Failures are deliberately logged instead of thrown so that one
                            // broken extension cannot fault the caller.
                            Logger.LogError($"Failed to start {ExtensionDisplayName}: {hr.Value}");

                            // Nothing was created; drop any stale reference to an earlier
                            // instance rather than marshaling a null pointer.
                            _extensionObject = null;
                            return;
                        }

                        _extensionObject = MarshalInterface<IExtension>.FromAbi((nint)extensionPtr);
                    }
                    catch (Exception e)
                    {
                        Logger.LogError($"Failed to start {ExtensionDisplayName}. ex: {e.Message}");
                    }
                    finally
                    {
                        // Release the raw pointer obtained from CoCreateInstance; the managed
                        // object created by FromAbi is used from here on.
                        if ((nint)extensionPtr != nint.Zero)
                        {
                            Marshal.Release((nint)extensionPtr);
                        }
                    }
                }
            }
        });
    }

    /// <summary>
    /// Asks a running extension to dispose itself and forgets the extension object.
    /// </summary>
    /// <remarks>
    /// The reference is cleared even when the liveness probe or the dispose call throws.
    /// </remarks>
    public void SignalDispose()
    {
        lock (_lock)
        {
            try
            {
                if (IsRunning())
                {
                    _extensionObject?.Dispose();
                }
            }
            finally
            {
                _extensionObject = null;
            }
        }
    }

    /// <summary>
    /// Gets the extension object if it is currently running.
    /// </summary>
    /// <returns>The extension object, or <see langword="null"/> when <see cref="IsRunning"/> is false.</returns>
    public IExtension? GetExtensionObject()
    {
        lock (_lock)
        {
            return IsRunning() ? _extensionObject : null;
        }
    }

    /// <summary>
    /// Starts the extension if necessary and requests the provider mapped to <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">A provider interface registered in the provider type map.</typeparam>
    /// <returns>
    /// The provider cast to <typeparamref name="T"/>, or <see langword="null"/> when the extension
    /// is not running or the returned object is not a <typeparamref name="T"/>.
    /// </returns>
    /// <exception cref="KeyNotFoundException"><typeparamref name="T"/> has no entry in the provider type map.</exception>
    public async Task<T?> GetProviderAsync<T>()
        where T : class
    {
        await StartExtensionAsync();

        return GetExtensionObject()?.GetProvider(_providerTypeMap[typeof(T)]) as T;
    }

    /// <summary>
    /// Starts the extension if necessary and returns every provider of type <typeparamref name="T"/>
    /// it exposes.
    /// </summary>
    /// <typeparam name="T">A provider interface registered in the provider type map.</typeparam>
    /// <returns>
    /// The sequence returned by the extension when it is an <see cref="IEnumerable{T}"/>, a
    /// single-element sequence when it is one <typeparamref name="T"/>, and an empty sequence otherwise.
    /// </returns>
    /// <exception cref="KeyNotFoundException"><typeparamref name="T"/> has no entry in the provider type map.</exception>
    public async Task<IEnumerable<T>> GetListOfProvidersAsync<T>()
        where T : class
    {
        await StartExtensionAsync();

        var supportedProviders = GetExtensionObject()?.GetProvider(_providerTypeMap[typeof(T)]);
        if (supportedProviders is IEnumerable<T> multipleProvidersSupported)
        {
            return multipleProvidersSupported;
        }

        if (supportedProviders is T singleProviderSupported)
        {
            return [singleProviderSupported];
        }

        return Enumerable.Empty<T>();
    }

    /// <summary>Records that this extension offers <paramref name="providerType"/>.</summary>
    /// <param name="providerType">The provider type to record.</param>
    public void AddProviderType(ProviderType providerType) => _providerTypes.Add(providerType);

    /// <summary>Reports whether <paramref name="providerType"/> was recorded via <see cref="AddProviderType"/>.</summary>
    /// <param name="providerType">The provider type to look for.</param>
    /// <returns><see langword="true"/> when the provider type has been recorded.</returns>
    public bool HasProviderType(ProviderType providerType) => _providerTypes.Contains(providerType);
}