/ src / runner / auto_start_helper.cpp
auto_start_helper.cpp
  1  #include "pch.h"
  2  #include "auto_start_helper.h"
  3  
  4  #include <Lmcons.h>
  5  
  6  #include <comdef.h>
  7  #include <taskschd.h>
  8  #include <common/logger/logger.h>
  9  
 10  // Helper macros from wix.
 11  #define ExitOnFailure(x, s, ...) \
 12      if (FAILED(x))               \
 13      {                            \
 14          Logger::error(s, ##__VA_ARGS__); \
 15          goto LExit;              \
 16      }
 17  #define ExitWithLastError(x, s, ...)       \
 18      {                                      \
 19          DWORD util_err = ::GetLastError(); \
 20          x = HRESULT_FROM_WIN32(util_err);  \
 21          if (!FAILED(x))                    \
 22          {                                  \
 23              x = E_FAIL;                    \
 24          }                                  \
 25          Logger::error(s, ##__VA_ARGS__);   \
 26          goto LExit;                        \
 27      }
 28  #define ExitFunction() \
 29      {                  \
 30          goto LExit;    \
 31      }
 32  
 33  const DWORD USERNAME_DOMAIN_LEN = DNLEN + UNLEN + 2; // Domain Name + '\' + User Name + '\0'
 34  const DWORD USERNAME_LEN = UNLEN + 1; // User Name + '\0'
 35  
 36  bool create_auto_start_task_for_this_user(bool runElevated)
 37  {
 38      HRESULT hr = S_OK;
 39  
 40      WCHAR username_domain[USERNAME_DOMAIN_LEN];
 41      WCHAR username[USERNAME_LEN];
 42  
 43      std::wstring wstrTaskName;
 44  
 45      ITaskService* pService = NULL;
 46      ITaskFolder* pTaskFolder = NULL;
 47      ITaskDefinition* pTask = NULL;
 48      IRegistrationInfo* pRegInfo = NULL;
 49      ITaskSettings* pSettings = NULL;
 50      ITriggerCollection* pTriggerCollection = NULL;
 51      IRegisteredTask* pRegisteredTask = NULL;
 52  
 53      // ------------------------------------------------------
 54      // Get the Domain/Username for the trigger.
 55      if (!GetEnvironmentVariable(L"USERNAME", username, USERNAME_LEN))
 56      {
 57          ExitWithLastError(hr, "Getting username failed: {:x}", hr);
 58      }
 59      if (!GetEnvironmentVariable(L"USERDOMAIN", username_domain, USERNAME_DOMAIN_LEN))
 60      {
 61          ExitWithLastError(hr, "Getting the user's domain failed: {:x}", hr);
 62      }
 63      wcscat_s(username_domain, L"\\");
 64      wcscat_s(username_domain, username);
 65  
 66      // Task Name.
 67      wstrTaskName = L"Autorun for ";
 68      wstrTaskName += username;
 69  
 70      // Get the executable path passed to the custom action.
 71      WCHAR wszExecutablePath[MAX_PATH];
 72      GetModuleFileName(NULL, wszExecutablePath, MAX_PATH);
 73  
 74      // ------------------------------------------------------
 75      // Create an instance of the Task Service.
 76      hr = CoCreateInstance(CLSID_TaskScheduler,
 77                            NULL,
 78                            CLSCTX_INPROC_SERVER,
 79                            IID_ITaskService,
 80                            reinterpret_cast<void**>(&pService));
 81      ExitOnFailure(hr, "Failed to create an instance of ITaskService: {:x}", hr);
 82  
 83      // Connect to the task service.
 84      hr = pService->Connect(_variant_t(), _variant_t(), _variant_t(), _variant_t());
 85      ExitOnFailure(hr, "ITaskService::Connect failed: {:x}", hr);
 86  
 87      // ------------------------------------------------------
 88      // Get the PowerToys task folder. Creates it if it doesn't exist.
 89      hr = pService->GetFolder(_bstr_t(L"\\PowerToys"), &pTaskFolder);
 90      if (FAILED(hr))
 91      {
 92          // Folder doesn't exist. Get the Root folder and create the PowerToys subfolder.
 93          ITaskFolder* pRootFolder = NULL;
 94          hr = pService->GetFolder(_bstr_t(L"\\"), &pRootFolder);
 95          ExitOnFailure(hr, "Cannot get Root Folder pointer: {:x}", hr);
 96          hr = pRootFolder->CreateFolder(_bstr_t(L"\\PowerToys"), _variant_t(L""), &pTaskFolder);
 97          if (FAILED(hr))
 98          {
 99              pRootFolder->Release();
100              ExitOnFailure(hr, "Cannot create PowerToys task folder: {:x}", hr);
101          }
102      }
103  
104      // If the task exists, just enable it.
105      {
106          IRegisteredTask* pExistingRegisteredTask = NULL;
107          hr = pTaskFolder->GetTask(_bstr_t(wstrTaskName.c_str()), &pExistingRegisteredTask);
108          if (SUCCEEDED(hr))
109          {
110              // Task exists, try enabling it.
111              hr = pExistingRegisteredTask->put_Enabled(VARIANT_TRUE);
112              pExistingRegisteredTask->Release();
113              if (SUCCEEDED(hr))
114              {
115                  // Function enable. Sounds like a success.
116                  ExitFunction();
117              }
118          }
119      }
120  
121      // Create the task builder object to create the task.
122      hr = pService->NewTask(0, &pTask);
123      ExitOnFailure(hr, "Failed to create a task definition: {:x}", hr);
124  
125      // ------------------------------------------------------
126      // Get the registration info for setting the identification.
127      hr = pTask->get_RegistrationInfo(&pRegInfo);
128      ExitOnFailure(hr, "Cannot get identification pointer: {:x}", hr);
129      hr = pRegInfo->put_Author(_bstr_t(username_domain));
130      ExitOnFailure(hr, "Cannot put identification info: {:x}", hr);
131  
132      // ------------------------------------------------------
133      // Create the settings for the task
134      hr = pTask->get_Settings(&pSettings);
135      ExitOnFailure(hr, "Cannot get settings pointer: {:x}", hr);
136  
137      hr = pSettings->put_StartWhenAvailable(VARIANT_FALSE);
138      ExitOnFailure(hr, "Cannot put_StartWhenAvailable setting info: {:x}", hr);
139      hr = pSettings->put_StopIfGoingOnBatteries(VARIANT_FALSE);
140      ExitOnFailure(hr, "Cannot put_StopIfGoingOnBatteries setting info: {:x}", hr);
141      hr = pSettings->put_ExecutionTimeLimit(_bstr_t(L"PT0S")); //Unlimited
142      ExitOnFailure(hr, "Cannot put_ExecutionTimeLimit setting info: {:x}", hr);
143      hr = pSettings->put_DisallowStartIfOnBatteries(VARIANT_FALSE);
144      ExitOnFailure(hr, "Cannot put_DisallowStartIfOnBatteries setting info: {:x}", hr);
145      hr = pSettings->put_Priority(4);
146      ExitOnFailure(hr, "Cannot put_Priority setting info : {:x}", hr);
147  
148      // ------------------------------------------------------
149      // Get the trigger collection to insert the logon trigger.
150      hr = pTask->get_Triggers(&pTriggerCollection);
151      ExitOnFailure(hr, "Cannot get trigger collection: {:x}", hr);
152  
153      // Add the logon trigger to the task.
154      {
155          ITrigger* pTrigger = NULL;
156          ILogonTrigger* pLogonTrigger = NULL;
157          hr = pTriggerCollection->Create(TASK_TRIGGER_LOGON, &pTrigger);
158          ExitOnFailure(hr, "Cannot create the trigger: {:x}", hr);
159  
160          hr = pTrigger->QueryInterface(
161              IID_ILogonTrigger, reinterpret_cast<void**>(&pLogonTrigger));
162          pTrigger->Release();
163          ExitOnFailure(hr, "QueryInterface call failed for ILogonTrigger: {:x}", hr);
164  
165          hr = pLogonTrigger->put_Id(_bstr_t(L"Trigger1"));
166  
167          // Timing issues may make explorer not be started when the task runs.
168          // Add a little delay to mitigate this.
169          hr = pLogonTrigger->put_Delay(_bstr_t(L"PT03S"));
170  
171          // Define the user. The task will execute when the user logs on.
172          // The specified user must be a user on this computer.
173          hr = pLogonTrigger->put_UserId(_bstr_t(username_domain));
174          pLogonTrigger->Release();
175          ExitOnFailure(hr, "Cannot add user ID to logon trigger: {:x}", hr);
176      }
177  
178      // ------------------------------------------------------
179      // Add an Action to the task. This task will execute the path passed to this custom action.
180      {
181          IActionCollection* pActionCollection = NULL;
182          IAction* pAction = NULL;
183          IExecAction* pExecAction = NULL;
184  
185          // Get the task action collection pointer.
186          hr = pTask->get_Actions(&pActionCollection);
187          ExitOnFailure(hr, "Cannot get Task collection pointer: {:x}", hr);
188  
189          // Create the action, specifying that it is an executable action.
190          hr = pActionCollection->Create(TASK_ACTION_EXEC, &pAction);
191          pActionCollection->Release();
192          ExitOnFailure(hr, "Cannot create the action: {:x}", hr);
193  
194          // QI for the executable task pointer.
195          hr = pAction->QueryInterface(
196              IID_IExecAction, reinterpret_cast<void**>(&pExecAction));
197          pAction->Release();
198          ExitOnFailure(hr, "QueryInterface call failed for IExecAction: {:x}", hr);
199  
200          // Set the path of the executable to PowerToys (passed as CustomActionData).
201          hr = pExecAction->put_Path(_bstr_t(wszExecutablePath));
202          pExecAction->Release();
203          ExitOnFailure(hr, "Cannot set path of executable: {:x}", hr);
204      }
205  
206      // ------------------------------------------------------
207      // Create the principal for the task
208      {
209          IPrincipal* pPrincipal = NULL;
210          hr = pTask->get_Principal(&pPrincipal);
211          ExitOnFailure(hr, "Cannot get principal pointer: {:x}", hr);
212  
213          // Set up principal information:
214          hr = pPrincipal->put_Id(_bstr_t(L"Principal1"));
215  
216          hr = pPrincipal->put_UserId(_bstr_t(username_domain));
217  
218          hr = pPrincipal->put_LogonType(TASK_LOGON_INTERACTIVE_TOKEN);
219  
220          if (runElevated)
221          {
222              hr = pPrincipal->put_RunLevel(_TASK_RUNLEVEL::TASK_RUNLEVEL_HIGHEST);
223          }
224          else
225          {
226              hr = pPrincipal->put_RunLevel(_TASK_RUNLEVEL::TASK_RUNLEVEL_LUA);
227          }
228          pPrincipal->Release();
229          ExitOnFailure(hr, "Cannot put principal run level: {:x}", hr);
230      }
231      // ------------------------------------------------------
232      //  Save the task in the PowerToys folder.
233      {
234          _variant_t SDDL_FULL_ACCESS_FOR_EVERYONE = L"D:(A;;FA;;;WD)";
235          hr = pTaskFolder->RegisterTaskDefinition(
236              _bstr_t(wstrTaskName.c_str()),
237              pTask,
238              TASK_CREATE_OR_UPDATE,
239              _variant_t(username_domain),
240              _variant_t(),
241              TASK_LOGON_INTERACTIVE_TOKEN,
242              SDDL_FULL_ACCESS_FOR_EVERYONE,
243              &pRegisteredTask);
244          ExitOnFailure(hr, "Error saving the Task : {:x}", hr);
245      }
246  
247  LExit:
248      if (pService)
249          pService->Release();
250      if (pTaskFolder)
251          pTaskFolder->Release();
252      if (pTask)
253          pTask->Release();
254      if (pRegInfo)
255          pRegInfo->Release();
256      if (pSettings)
257          pSettings->Release();
258      if (pTriggerCollection)
259          pTriggerCollection->Release();
260      if (pRegisteredTask)
261          pRegisteredTask->Release();
262  
263      return (SUCCEEDED(hr));
264  }
265  
266  bool delete_auto_start_task_for_this_user()
267  {
268      HRESULT hr = S_OK;
269  
270      WCHAR username[USERNAME_LEN];
271      std::wstring wstrTaskName;
272  
273      ITaskService* pService = NULL;
274      ITaskFolder* pTaskFolder = NULL;
275  
276      // ------------------------------------------------------
277      // Get the Username for the task.
278      if (!GetEnvironmentVariable(L"USERNAME", username, USERNAME_LEN))
279      {
280          ExitWithLastError(hr, "Getting username failed: {:x}", hr);
281      }
282  
283      // Task Name.
284      wstrTaskName = L"Autorun for ";
285      wstrTaskName += username;
286  
287      // ------------------------------------------------------
288      // Create an instance of the Task Service.
289      hr = CoCreateInstance(CLSID_TaskScheduler,
290                            NULL,
291                            CLSCTX_INPROC_SERVER,
292                            IID_ITaskService,
293                            reinterpret_cast<void**>(&pService));
294      ExitOnFailure(hr, "Failed to create an instance of ITaskService: {:x}", hr);
295  
296      // Connect to the task service.
297      hr = pService->Connect(_variant_t(), _variant_t(), _variant_t(), _variant_t());
298      ExitOnFailure(hr, "ITaskService::Connect failed: {:x}", hr);
299  
300      // ------------------------------------------------------
301      // Get the PowerToys task folder.
302      hr = pService->GetFolder(_bstr_t(L"\\PowerToys"), &pTaskFolder);
303      if (FAILED(hr))
304      {
305          // Folder doesn't exist. No need to disable a non-existing task.
306          hr = S_OK;
307          ExitFunction();
308      }
309  
310      // ------------------------------------------------------
311      // If the task exists, disable.
312      {
313          IRegisteredTask* pExistingRegisteredTask = NULL;
314          hr = pTaskFolder->GetTask(_bstr_t(wstrTaskName.c_str()), &pExistingRegisteredTask);
315          if (SUCCEEDED(hr))
316          {
317              // Task exists, try disabling it.
318              hr = pTaskFolder->DeleteTask(_bstr_t(wstrTaskName.c_str()), 0);
319          }
320      }
321  
322  LExit:
323      if (pService)
324          pService->Release();
325      if (pTaskFolder)
326          pTaskFolder->Release();
327  
328      return (SUCCEEDED(hr));
329  }
330  
331  bool is_auto_start_task_active_for_this_user()
332  {
333      HRESULT hr = S_OK;
334  
335      WCHAR username[USERNAME_LEN];
336      std::wstring wstrTaskName;
337  
338      ITaskService* pService = NULL;
339      ITaskFolder* pTaskFolder = NULL;
340  
341      // ------------------------------------------------------
342      // Get the Username for the task.
343      if (!GetEnvironmentVariable(L"USERNAME", username, USERNAME_LEN))
344      {
345          ExitWithLastError(hr, "Getting username failed: {:x}", hr);
346      }
347  
348      // Task Name.
349      wstrTaskName = L"Autorun for ";
350      wstrTaskName += username;
351  
352      // ------------------------------------------------------
353      // Create an instance of the Task Service.
354      hr = CoCreateInstance(CLSID_TaskScheduler,
355                            NULL,
356                            CLSCTX_INPROC_SERVER,
357                            IID_ITaskService,
358                            reinterpret_cast<void**>(&pService));
359      ExitOnFailure(hr, "Failed to create an instance of ITaskService: {:x}", hr);
360  
361      // Connect to the task service.
362      hr = pService->Connect(_variant_t(), _variant_t(), _variant_t(), _variant_t());
363      ExitOnFailure(hr, "ITaskService::Connect failed: {:x}", hr);
364  
365      // ------------------------------------------------------
366      // Get the PowerToys task folder.
367      hr = pService->GetFolder(_bstr_t(L"\\PowerToys"), &pTaskFolder);
368      ExitOnFailure(hr, "ITaskFolder doesn't exist: {:x}", hr);
369  
370      // ------------------------------------------------------
371      // If the task exists, disable.
372      {
373          IRegisteredTask* pExistingRegisteredTask = NULL;
374          hr = pTaskFolder->GetTask(_bstr_t(wstrTaskName.c_str()), &pExistingRegisteredTask);
375          if (SUCCEEDED(hr))
376          {
377              // Task exists, get its value.
378              VARIANT_BOOL is_enabled;
379              hr = pExistingRegisteredTask->get_Enabled(&is_enabled);
380              pExistingRegisteredTask->Release();
381              if (SUCCEEDED(hr))
382              {
383                  // Got the value. Return it.
384                  hr = (is_enabled == VARIANT_TRUE) ? S_OK : E_FAIL; // Fake success or fail to return the value.
385                  ExitFunction();
386              }
387          }
388      }
389  
390  LExit:
391      if (pService)
392          pService->Release();
393      if (pTaskFolder)
394          pTaskFolder->Release();
395  
396      return (SUCCEEDED(hr));
397  }