/ src / common / utils / package.h
package.h
  1  #pragma once
  2  
#include <Windows.h>

#include <appxpackaging.h>
#include <Shlwapi.h>
#include <wrl/client.h>

#include <algorithm>
#include <exception>
#include <filesystem>
#include <memory>
#include <optional>
#include <regex>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>
#include <vector>

#include <winrt/Windows.ApplicationModel.h>
#include <winrt/Windows.Foundation.h>
#include <winrt/Windows.Foundation.Collections.h>
#include <winrt/Windows.Management.Deployment.h>

#include "../logger/logger.h"
#include "../version/version.h"
 21  
 22  namespace package
 23  {
 24      using winrt::Windows::ApplicationModel::Package;
 25      using winrt::Windows::Foundation::IAsyncOperationWithProgress;
 26      using winrt::Windows::Foundation::AsyncStatus;
 27      using winrt::Windows::Foundation::Uri;
 28      using winrt::Windows::Foundation::Collections::IVector;
 29      using winrt::Windows::Management::Deployment::AddPackageOptions;
 30      using winrt::Windows::Management::Deployment::DeploymentOptions;
 31      using winrt::Windows::Management::Deployment::DeploymentProgress;
 32      using winrt::Windows::Management::Deployment::DeploymentResult;
 33      using winrt::Windows::Management::Deployment::PackageManager;
 34      using Microsoft::WRL::ComPtr;
 35  
 36      inline BOOL IsWin11OrGreater()
 37      {
 38          OSVERSIONINFOEX osvi{};
 39          DWORDLONG dwlConditionMask = 0;
 40          byte op = VER_GREATER_EQUAL;
 41  
 42          // Initialize the OSVERSIONINFOEX structure.
 43          osvi.dwOSVersionInfoSize = sizeof(OSVERSIONINFOEX);
 44          osvi.dwMajorVersion = HIBYTE(_WIN32_WINNT_WINTHRESHOLD);
 45          osvi.dwMinorVersion = LOBYTE(_WIN32_WINNT_WINTHRESHOLD);
 46          // Windows 11 build number
 47          osvi.dwBuildNumber = 22000;
 48  
 49          // Initialize the condition mask.
 50          VER_SET_CONDITION(dwlConditionMask, VER_MAJORVERSION, op);
 51          VER_SET_CONDITION(dwlConditionMask, VER_MINORVERSION, op);
 52          VER_SET_CONDITION(dwlConditionMask, VER_BUILDNUMBER, op);
 53  
 54          // Perform the test.
 55          return VerifyVersionInfo(
 56              &osvi,
 57              VER_MAJORVERSION | VER_MINORVERSION | VER_BUILDNUMBER,
 58              dwlConditionMask);
 59      }
 60  
 61      struct PACKAGE_VERSION
 62      {
 63          UINT16 Major;
 64          UINT16 Minor;
 65          UINT16 Build;
 66          UINT16 Revision;
 67      };
 68  
 69      class ComInitializer
 70      {
 71      public:
 72          explicit ComInitializer(DWORD coInitFlags = COINIT_MULTITHREADED) :
 73              _initialized(false)
 74          {
 75              const HRESULT hr = CoInitializeEx(nullptr, coInitFlags);
 76              _initialized = SUCCEEDED(hr);
 77          }
 78  
 79          ~ComInitializer()
 80          {
 81              if (_initialized)
 82              {
 83                  CoUninitialize();
 84              }
 85          }
 86  
 87          bool Succeeded() const { return _initialized; }
 88  
 89      private:
 90          bool _initialized;
 91      };
 92  
 93      inline bool GetPackageNameAndVersionFromAppx(
 94          const std::wstring& appxPath,
 95          std::wstring& outName,
 96          PACKAGE_VERSION& outVersion)
 97      {
 98          try
 99          {
100              ComInitializer comInit;
101              if (!comInit.Succeeded())
102              {
103                  Logger::error(L"COM initialization failed.");
104                  return false;
105              }
106  
107              ComPtr<IAppxFactory> factory;
108              ComPtr<IStream> stream;
109              ComPtr<IAppxPackageReader> reader;
110              ComPtr<IAppxManifestReader> manifest;
111              ComPtr<IAppxManifestPackageId> packageId;
112  
113              HRESULT hr = CoCreateInstance(__uuidof(AppxFactory), nullptr, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&factory));
114              if (FAILED(hr))
115                  return false;
116  
117              hr = SHCreateStreamOnFileEx(appxPath.c_str(), STGM_READ | STGM_SHARE_DENY_WRITE, FILE_ATTRIBUTE_NORMAL, FALSE, nullptr, &stream);
118              if (FAILED(hr))
119                  return false;
120  
121              hr = factory->CreatePackageReader(stream.Get(), &reader);
122              if (FAILED(hr))
123                  return false;
124  
125              hr = reader->GetManifest(&manifest);
126              if (FAILED(hr))
127                  return false;
128  
129              hr = manifest->GetPackageId(&packageId);
130              if (FAILED(hr))
131                  return false;
132  
133              LPWSTR name = nullptr;
134              hr = packageId->GetName(&name);
135              if (FAILED(hr))
136                  return false;
137  
138              UINT64 version = 0;
139              hr = packageId->GetVersion(&version);
140              if (FAILED(hr))
141                  return false;
142  
143              outName = std::wstring(name);
144              CoTaskMemFree(name);
145  
146              outVersion.Major = static_cast<UINT16>((version >> 48) & 0xFFFF);
147              outVersion.Minor = static_cast<UINT16>((version >> 32) & 0xFFFF);
148              outVersion.Build = static_cast<UINT16>((version >> 16) & 0xFFFF);
149              outVersion.Revision = static_cast<UINT16>(version & 0xFFFF);
150  
151              Logger::info(L"Package name: {}, version: {}.{}.{}.{}, appxPath: {}",
152                           outName,
153                           outVersion.Major,
154                           outVersion.Minor,
155                           outVersion.Build,
156                           outVersion.Revision,
157                           appxPath);
158  
159              return true;
160          }
161          catch (const std::exception& ex)
162          {
163              Logger::error(L"Standard exception: {}", winrt::to_hstring(ex.what()));
164              return false;
165          }
166          catch (...)
167          {
168              Logger::error(L"Unknown or non-standard exception occurred.");
169              return false;
170          }
171      }
172  
173      inline std::optional<Package> GetRegisteredPackage(std::wstring packageDisplayName, bool checkVersion)
174      {
175          PackageManager packageManager;
176  
177          for (const auto& package : packageManager.FindPackagesForUser({}))
178          {
179              const auto& packageFullName = std::wstring{ package.Id().FullName() };
180              const auto& packageVersion = package.Id().Version();
181  
182              if (packageFullName.contains(packageDisplayName))
183              {
184                  // If checkVersion is true, verify if the package has the same version as PowerToys.
185                  if ((!checkVersion) || (packageVersion.Major == VERSION_MAJOR && packageVersion.Minor == VERSION_MINOR && packageVersion.Revision == VERSION_REVISION))
186                  {
187                      return { package };
188                  }
189              }
190          }
191  
192          return {};
193      }
194  
195      inline bool IsPackageRegisteredWithPowerToysVersion(std::wstring packageDisplayName)
196      {
197          return GetRegisteredPackage(packageDisplayName, true).has_value();
198      }
199  
200      inline bool RegisterSparsePackage(const std::wstring& externalLocation, const std::wstring& sparsePkgPath)
201      {
202          try
203          {
204              Uri externalUri{ externalLocation };
205              Uri packageUri{ sparsePkgPath };
206  
207              PackageManager packageManager;
208  
209              // Declare use of an external location
210              AddPackageOptions options;
211              options.ExternalLocationUri(externalUri);
212              options.ForceUpdateFromAnyVersion(true);
213  
214              IAsyncOperationWithProgress<DeploymentResult, DeploymentProgress> deploymentOperation = packageManager.AddPackageByUriAsync(packageUri, options);
215              deploymentOperation.get();
216  
217              // Check the status of the operation
218              if (deploymentOperation.Status() == AsyncStatus::Error)
219              {
220                  auto deploymentResult{ deploymentOperation.GetResults() };
221                  auto errorCode = deploymentOperation.ErrorCode();
222                  auto errorText = deploymentResult.ErrorText();
223  
224                  Logger::error(L"Register {} package failed. ErrorCode: {}, ErrorText: {}", sparsePkgPath, std::to_wstring(errorCode), errorText);
225                  return false;
226              }
227              else if (deploymentOperation.Status() == AsyncStatus::Canceled)
228              {
229                  Logger::error(L"Register {} package canceled.", sparsePkgPath);
230                  return false;
231              }
232              else if (deploymentOperation.Status() == AsyncStatus::Completed)
233              {
234                  Logger::info(L"Register {} package completed.", sparsePkgPath);
235              }
236              else
237              {
238                  Logger::debug(L"Register {} package started.", sparsePkgPath);
239              }
240  
241              return true;
242          }
243          catch (std::exception& e)
244          {
245              Logger::error("Exception thrown while trying to register package: {}", e.what());
246  
247              return false;
248          }
249      }
250  
251      inline bool UnRegisterPackage(const std::wstring& pkgDisplayName)
252      {
253          try
254          {
255              PackageManager packageManager;
256              const static auto packages = packageManager.FindPackagesForUser({});
257  
258              for (auto const& package : packages)
259              {
260                  const auto& packageFullName = std::wstring{ package.Id().FullName() };
261  
262                  if (packageFullName.contains(pkgDisplayName))
263                  {
264                      auto deploymentOperation{ packageManager.RemovePackageAsync(packageFullName) };
265                      deploymentOperation.get();
266  
267                      // Check the status of the operation
268                      if (deploymentOperation.Status() == AsyncStatus::Error)
269                      {
270                          auto deploymentResult{ deploymentOperation.GetResults() };
271                          auto errorCode = deploymentOperation.ErrorCode();
272                          auto errorText = deploymentResult.ErrorText();
273  
274                          Logger::error(L"Unregister {} package failed. ErrorCode: {}, ErrorText: {}", packageFullName, std::to_wstring(errorCode), errorText);
275                      }
276                      else if (deploymentOperation.Status() == AsyncStatus::Canceled)
277                      {
278                          Logger::error(L"Unregister {} package canceled.", packageFullName);
279                      }
280                      else if (deploymentOperation.Status() == AsyncStatus::Completed)
281                      {
282                          Logger::info(L"Unregister {} package completed.", packageFullName);
283                      }
284                      else
285                      {
286                          Logger::debug(L"Unregister {} package started.", packageFullName);
287                      }
288  
289                      break;
290                  }
291              }
292          }
293          catch (std::exception& e)
294          {
295              Logger::error("Exception thrown while trying to unregister package: {}", e.what());
296              return false;
297          }
298  
299          return true;
300      }
301  
302      inline std::vector<std::wstring> FindMsixFile(const std::wstring& directoryPath, bool recursive)
303      {
304          if (directoryPath.empty())
305          {
306              return {};
307          }
308  
309          if (!std::filesystem::exists(directoryPath))
310          {
311              Logger::error(L"The directory '" + directoryPath + L"' does not exist.");
312              return {};
313          }
314  
315          const std::regex pattern(R"(^.+\.(appx|msix|msixbundle)$)", std::regex_constants::icase);
316          std::vector<std::wstring> matchedFiles;
317  
318          try
319          {
320              if (recursive)
321              {
322                  for (const auto& entry : std::filesystem::recursive_directory_iterator(directoryPath))
323                  {
324                      if (entry.is_regular_file())
325                      {
326                          const auto& fileName = entry.path().filename().string();
327                          if (std::regex_match(fileName, pattern))
328                          {
329                              matchedFiles.push_back(entry.path());
330                          }
331                      }
332                  }
333              }
334              else
335              {
336                  for (const auto& entry : std::filesystem::directory_iterator(directoryPath))
337                  {
338                      if (entry.is_regular_file())
339                      {
340                          const auto& fileName = entry.path().filename().string();
341                          if (std::regex_match(fileName, pattern))
342                          {
343                              matchedFiles.push_back(entry.path());
344                          }
345                      }
346                  }
347              }
348  
349              // Sort by package version in descending order (newest first)
350              std::sort(matchedFiles.begin(), matchedFiles.end(), [](const std::wstring& a, const std::wstring& b) {
351                  std::wstring nameA, nameB;
352                  PACKAGE_VERSION versionA{}, versionB{};
353  
354                  bool gotA = GetPackageNameAndVersionFromAppx(a, nameA, versionA);
355                  bool gotB = GetPackageNameAndVersionFromAppx(b, nameB, versionB);
356  
357                  // Files that failed to parse go to the end
358                  if (!gotA)
359                      return false;
360                  if (!gotB)
361                      return true;
362  
363                  // Compare versions: Major, Minor, Build, Revision (descending)
364                  if (versionA.Major != versionB.Major)
365                      return versionA.Major > versionB.Major;
366                  if (versionA.Minor != versionB.Minor)
367                      return versionA.Minor > versionB.Minor;
368                  if (versionA.Build != versionB.Build)
369                      return versionA.Build > versionB.Build;
370                  return versionA.Revision > versionB.Revision;
371              });
372          }
373          catch (const std::exception& ex)
374          {
375              Logger::error("An error occurred while searching for MSIX files: " + std::string(ex.what()));
376          }
377  
378          return matchedFiles;
379      }
380  
381      inline bool IsPackageSatisfied(const std::wstring& appxPath)
382      {
383          std::wstring targetName;
384          PACKAGE_VERSION targetVersion{};
385  
386          if (!GetPackageNameAndVersionFromAppx(appxPath, targetName, targetVersion))
387          {
388              Logger::error(L"Failed to get package name and version from appx: " + appxPath);
389              return false;
390          }
391  
392          PackageManager pm;
393  
394          for (const auto& package : pm.FindPackagesForUser({}))
395          {
396              const auto& id = package.Id();
397              if (std::wstring(id.Name()) == targetName)
398              {
399                  const auto& version = id.Version();
400  
401                  if (version.Major > targetVersion.Major ||
402                      (version.Major == targetVersion.Major && version.Minor > targetVersion.Minor) ||
403                      (version.Major == targetVersion.Major && version.Minor == targetVersion.Minor && version.Build > targetVersion.Build) ||
404                      (version.Major == targetVersion.Major && version.Minor == targetVersion.Minor && version.Build == targetVersion.Build && version.Revision >= targetVersion.Revision))
405                  {
406                      Logger::info(
407                          L"Package {} is already satisfied with version {}.{}.{}.{}; target version {}.{}.{}.{}; appxPath: {}",
408                          id.Name(),
409                          version.Major,
410                          version.Minor,
411                          version.Build,
412                          version.Revision,
413                          targetVersion.Major,
414                          targetVersion.Minor,
415                          targetVersion.Build,
416                          targetVersion.Revision,
417                          appxPath);
418                      return true;
419                  }
420              }
421          }
422  
423          Logger::info(
424              L"Package {} is not satisfied. Target version: {}.{}.{}.{}; appxPath: {}",
425              targetName,
426              targetVersion.Major,
427              targetVersion.Minor,
428              targetVersion.Build,
429              targetVersion.Revision,
430              appxPath);
431          return false;
432      }
433  
434      inline bool RegisterPackage(std::wstring pkgPath, std::vector<std::wstring> dependencies)
435      {
436          try
437          {
438              Uri packageUri{ pkgPath };
439  
440              PackageManager packageManager;
441  
442              // Declare use of an external location
443              DeploymentOptions options = DeploymentOptions::ForceTargetApplicationShutdown;
444  
445              IVector<Uri> uris = winrt::single_threaded_vector<Uri>();
446              if (!dependencies.empty())
447              {
448                  for (const auto& dependency : dependencies)
449                  {
450                      try
451                      {
452                          if (IsPackageSatisfied(dependency))
453                          {
454                              Logger::info(L"Dependency already satisfied: {}", dependency);
455                          }
456                          else
457                          {
458                              uris.Append(Uri(dependency));
459                          }
460                      }
461                      catch (const winrt::hresult_error& ex)
462                      {
463                          Logger::error(L"Error creating Uri for dependency: %s", ex.message().c_str());
464                      }
465                  }
466              }
467  
468              IAsyncOperationWithProgress<DeploymentResult, DeploymentProgress> deploymentOperation = packageManager.AddPackageAsync(packageUri, uris, options);
469              deploymentOperation.get();
470  
471              // Check the status of the operation
472              if (deploymentOperation.Status() == AsyncStatus::Error)
473              {
474                  auto deploymentResult{ deploymentOperation.GetResults() };
475                  auto errorCode = deploymentOperation.ErrorCode();
476                  auto errorText = deploymentResult.ErrorText();
477  
478                  Logger::error(L"Register {} package failed. ErrorCode: {}, ErrorText: {}", pkgPath, std::to_wstring(errorCode), errorText);
479                  return false;
480              }
481              else if (deploymentOperation.Status() == AsyncStatus::Canceled)
482              {
483                  Logger::error(L"Register {} package canceled.", pkgPath);
484                  return false;
485              }
486              else if (deploymentOperation.Status() == AsyncStatus::Completed)
487              {
488                  Logger::info(L"Register {} package completed.", pkgPath);
489              }
490              else
491              {
492                  Logger::debug(L"Register {} package started.", pkgPath);
493              }
494          }
495          catch (std::exception& e)
496          {
497              Logger::error("Exception thrown while trying to register package: {}", e.what());
498  
499              return false;
500          }
501  
502          return true;
503      }
504  }