ExtensionWrapper.cs
  1  // Copyright (c) Microsoft Corporation
  2  // The Microsoft Corporation licenses this file to you under the MIT license.
  3  // See the LICENSE file in the project root for more information.
  4  
  5  using System.Runtime.InteropServices;
  6  using ManagedCommon;
  7  using Microsoft.CmdPal.Core.Common.Services;
  8  using Microsoft.CommandPalette.Extensions;
  9  using Windows.ApplicationModel;
 10  using Windows.ApplicationModel.AppExtensions;
 11  using Windows.Win32;
 12  using Windows.Win32.System.Com;
 13  using WinRT;
 14  
 15  namespace Microsoft.CmdPal.UI.ViewModels.Models;
 16  
 17  public class ExtensionWrapper : IExtensionWrapper
 18  {
 19      private const int HResultRpcServerNotRunning = -2147023174;
 20  
 21      private readonly string _appUserModelId;
 22      private readonly string _extensionId;
 23  
 24      private readonly Lock _lock = new();
 25      private readonly List<ProviderType> _providerTypes = [];
 26  
 27      private readonly Dictionary<Type, ProviderType> _providerTypeMap = new()
 28      {
 29          [typeof(ICommandProvider)] = ProviderType.Commands,
 30      };
 31  
 32      private IExtension? _extensionObject;
 33  
 34      public ExtensionWrapper(AppExtension appExtension, string classId)
 35      {
 36          PackageDisplayName = appExtension.Package.DisplayName;
 37          ExtensionDisplayName = appExtension.DisplayName;
 38          PackageFullName = appExtension.Package.Id.FullName;
 39          PackageFamilyName = appExtension.Package.Id.FamilyName;
 40          ExtensionClassId = classId ?? throw new ArgumentNullException(nameof(classId));
 41          Publisher = appExtension.Package.PublisherDisplayName;
 42          InstalledDate = appExtension.Package.InstalledDate;
 43          Version = appExtension.Package.Id.Version;
 44          _appUserModelId = appExtension.AppInfo.AppUserModelId;
 45          _extensionId = appExtension.Id;
 46      }
 47  
 48      public string PackageDisplayName { get; }
 49  
 50      public string ExtensionDisplayName { get; }
 51  
 52      public string PackageFullName { get; }
 53  
 54      public string PackageFamilyName { get; }
 55  
 56      public string ExtensionClassId { get; }
 57  
 58      public string Publisher { get; }
 59  
 60      public DateTimeOffset InstalledDate { get; }
 61  
 62      public PackageVersion Version { get; }
 63  
 64      /// <summary>
 65      /// Gets the unique id for this Dev Home extension. The unique id is a concatenation of:
 66      /// <list type="number">
 67      /// <item>The AppUserModelId (AUMID) of the extension's application. The AUMID is the concatenation of the package
 68      /// family name and the application id and uniquely identifies the application containing the extension within
 69      /// the package.</item>
 70      /// <item>The Extension Id. This is the unique identifier of the extension within the application.</item>
 71      /// </list>
 72      /// </summary>
 73      public string ExtensionUniqueId => _appUserModelId + "!" + _extensionId;
 74  
 75      public bool IsRunning()
 76      {
 77          if (_extensionObject is null)
 78          {
 79              return false;
 80          }
 81  
 82          try
 83          {
 84              _extensionObject.As<IInspectable>().GetRuntimeClassName();
 85          }
 86          catch (COMException e)
 87          {
 88              if (e.ErrorCode == HResultRpcServerNotRunning)
 89              {
 90                  return false;
 91              }
 92  
 93              throw;
 94          }
 95  
 96          return true;
 97      }
 98  
 99      public async Task StartExtensionAsync()
100      {
101          await Task.Run(() =>
102          {
103              lock (_lock)
104              {
105                  if (!IsRunning())
106                  {
107                      Logger.LogDebug($"Starting {ExtensionDisplayName} ({ExtensionClassId})");
108  
109                      unsafe
110                      {
111                          var extensionPtr = (void*)nint.Zero;
112                          try
113                          {
114                              // -2147024809: E_INVALIDARG
115                              // -2147467262: E_NOINTERFACE
116                              // -2147024893: E_PATH_NOT_FOUND
117                              var guid = typeof(IExtension).GUID;
118  
119                              var hr = PInvoke.CoCreateInstance(Guid.Parse(ExtensionClassId), null, CLSCTX.CLSCTX_LOCAL_SERVER, guid, out extensionPtr);
120  
121                              if (hr.Value == -2147024893)
122                              {
123                                  Logger.LogError($"Failed to find {ExtensionDisplayName}: {hr}. It may have been uninstalled or deleted.");
124  
125                                  // We don't really need to throw this exception.
126                                  // We'll just return out nothing.
127                                  return;
128                              }
129                              else if (hr.Value != 0)
130                              {
131                                  Logger.LogError($"Failed to find {ExtensionDisplayName}: {hr.Value}");
132                              }
133  
134                              // Marshal.ThrowExceptionForHR(hr);
135                              _extensionObject = MarshalInterface<IExtension>.FromAbi((nint)extensionPtr);
136                          }
137                          catch (Exception e)
138                          {
139                              Logger.LogDebug($"Failed to start {ExtensionDisplayName}. ex: {e.Message}");
140                          }
141                          finally
142                          {
143                              if ((nint)extensionPtr != nint.Zero)
144                              {
145                                  Marshal.Release((nint)extensionPtr);
146                              }
147                          }
148                      }
149                  }
150              }
151          });
152      }
153  
154      public void SignalDispose()
155      {
156          lock (_lock)
157          {
158              if (IsRunning())
159              {
160                  _extensionObject?.Dispose();
161              }
162  
163              _extensionObject = null;
164          }
165      }
166  
167      public IExtension? GetExtensionObject()
168      {
169          lock (_lock)
170          {
171              return IsRunning() ? _extensionObject : null;
172          }
173      }
174  
175      public async Task<T?> GetProviderAsync<T>()
176          where T : class
177      {
178          await StartExtensionAsync();
179  
180          return GetExtensionObject()?.GetProvider(_providerTypeMap[typeof(T)]) as T;
181      }
182  
183      public async Task<IEnumerable<T>> GetListOfProvidersAsync<T>()
184          where T : class
185      {
186          await StartExtensionAsync();
187  
188          var supportedProviders = GetExtensionObject()?.GetProvider(_providerTypeMap[typeof(T)]);
189          if (supportedProviders is IEnumerable<T> multipleProvidersSupported)
190          {
191              return multipleProvidersSupported;
192          }
193          else if (supportedProviders is T singleProviderSupported)
194          {
195              return [singleProviderSupported];
196          }
197  
198          return Enumerable.Empty<T>();
199      }
200  
201      public void AddProviderType(ProviderType providerType) => _providerTypes.Add(providerType);
202  
203      public bool HasProviderType(ProviderType providerType) => _providerTypes.Contains(providerType);
204  }