// src/runner/centralized_kb_hook.cpp
  1  #include "pch.h"
  2  #include "centralized_kb_hook.h"
  3  #include <common/debug_control.h>
  4  #include <common/utils/winapi_error.h>
  5  #include <common/logger/logger.h>
  6  #include <common/interop/shared_constants.h>
  7  
  8  namespace CentralizedKeyboardHook
  9  {
 10      struct HotkeyDescriptor
 11      {
 12          Hotkey hotkey;
 13          std::wstring moduleName;
 14          std::function<bool()> action;
 15  
 16          bool operator<(const HotkeyDescriptor& other) const
 17          {
 18              return hotkey < other.hotkey;
 19          };
 20      };
 21  
 22      std::multiset<HotkeyDescriptor> hotkeyDescriptors;
 23      std::mutex mutex;
 24      HHOOK hHook{};
 25  
 26      // To store information about handling pressed keys.
 27      struct PressedKeyDescriptor
 28      {
 29          DWORD virtualKey; // Virtual Key code of the key we're keeping track of.
 30          std::wstring moduleName;
 31          std::function<bool()> action;
 32          UINT_PTR idTimer; // Timer ID for calling SET_TIMER with.
 33          UINT millisecondsToPress; // How much time the key must be pressed.
 34          bool operator<(const PressedKeyDescriptor& other) const
 35          {
 36              // We'll use the virtual key as the real key, since looking for a hit with the key is done in the more time sensitive path (low level keyboard hook).
 37              return virtualKey < other.virtualKey;
 38          };
 39      };
 40      std::multiset<PressedKeyDescriptor> pressedKeyDescriptors;
 41      std::mutex pressedKeyMutex;
 42  
 43      // keep track of last pressed key, to detect repeated keys and if there are more keys pressed.
 44      const DWORD VK_DISABLED = CommonSharedConstants::VK_DISABLED;
 45      DWORD vkCodePressed = VK_DISABLED;
 46  
 47      // Save the runner window handle for registering timers.
 48      HWND runnerWindow;
 49  
 50      struct DestroyOnExit
 51      {
 52          ~DestroyOnExit()
 53          {
 54              Stop();
 55          }
 56      } destroyOnExitObj;
 57  
 58      // Handle the pressed key proc
 59      void PressedKeyTimerProc(
 60          HWND hwnd,
 61          UINT /*message*/,
 62          UINT_PTR idTimer,
 63          DWORD /*dwTime*/)
 64      {
 65          std::multiset<PressedKeyDescriptor> copy;
 66          {
 67              // Make a copy, to look for the action to call.
 68              std::unique_lock lock{ pressedKeyMutex };
 69              copy = pressedKeyDescriptors;
 70          }
 71          for (const auto& it : copy)
 72          {
 73              if (it.idTimer == idTimer)
 74              {
 75                  it.action();
 76              }
 77          }
 78  
 79          KillTimer(hwnd, idTimer);
 80      }
 81  
 82      LRESULT CALLBACK KeyboardHookProc(_In_ int nCode, _In_ WPARAM wParam, _In_ LPARAM lParam)
 83      {
 84          if (nCode < 0)
 85          {
 86              return CallNextHookEx(hHook, nCode, wParam, lParam);
 87          }
 88  
 89          const auto& keyPressInfo = *reinterpret_cast<KBDLLHOOKSTRUCT*>(lParam);
 90  
 91          if (keyPressInfo.dwExtraInfo == PowertoyModuleIface::CENTRALIZED_KEYBOARD_HOOK_DONT_TRIGGER_FLAG)
 92          {
 93              // The new keystroke was generated from one of our actions. We should pass it along.
 94              return CallNextHookEx(hHook, nCode, wParam, lParam);
 95          }
 96  
 97          // Check if the keys are pressed.
 98          if (!pressedKeyDescriptors.empty())
 99          {
100              bool wasKeyPressed = vkCodePressed != VK_DISABLED;
101              // Hold the lock for the shortest possible duration
102              if ((wParam == WM_KEYDOWN || wParam == WM_SYSKEYDOWN))
103              {
104                  if (!wasKeyPressed)
105                  {
106                      // If no key was pressed before, let's start a timer to take into account this new key.
107                      std::unique_lock lock{ pressedKeyMutex };
108                      PressedKeyDescriptor dummy{ .virtualKey = keyPressInfo.vkCode };
109                      auto [it, last] = pressedKeyDescriptors.equal_range(dummy);
110                      for (; it != last; ++it)
111                      {
112                          SetTimer(runnerWindow, it->idTimer, it->millisecondsToPress, PressedKeyTimerProc);
113                      }
114                  }
115                  else if (vkCodePressed != keyPressInfo.vkCode)
116                  {
117                      // If a different key was pressed, let's clear the timers we have started for the previous key.
118                      std::unique_lock lock{ pressedKeyMutex };
119                      PressedKeyDescriptor dummy{ .virtualKey = vkCodePressed };
120                      auto [it, last] = pressedKeyDescriptors.equal_range(dummy);
121                      for (; it != last; ++it)
122                      {
123                          KillTimer(runnerWindow, it->idTimer);
124                      }
125                  }
126                  vkCodePressed = keyPressInfo.vkCode;
127              }
128              if (wParam == WM_KEYUP || wParam == WM_SYSKEYUP)
129              {
130                  std::unique_lock lock{ pressedKeyMutex };
131                  PressedKeyDescriptor dummy{ .virtualKey = keyPressInfo.vkCode };
132                  auto [it, last] = pressedKeyDescriptors.equal_range(dummy);
133                  for (; it != last; ++it)
134                  {
135                      KillTimer(runnerWindow, it->idTimer);
136                  }
137                  vkCodePressed = 0x100;
138              }
139          }
140  
141          if ((wParam != WM_KEYDOWN) && (wParam != WM_SYSKEYDOWN))
142          {
143              return CallNextHookEx(hHook, nCode, wParam, lParam);
144          }
145  
146          Hotkey hotkey{
147              .win = (GetAsyncKeyState(VK_LWIN) & 0x8000) || (GetAsyncKeyState(VK_RWIN) & 0x8000),
148              .ctrl = static_cast<bool>(GetAsyncKeyState(VK_CONTROL) & 0x8000),
149              .shift = static_cast<bool>(GetAsyncKeyState(VK_SHIFT) & 0x8000),
150              .alt = static_cast<bool>(GetAsyncKeyState(VK_MENU) & 0x8000),
151              .key = static_cast<unsigned char>(keyPressInfo.vkCode)
152          };
153  
154          if (hotkey == Hotkey{})
155          {
156              return CallNextHookEx(hHook, nCode, wParam, lParam);
157          }
158  
159          std::function<bool()> action;
160          {
161              // Hold the lock for the shortest possible duration
162              std::unique_lock lock{ mutex };
163              HotkeyDescriptor dummy{ .hotkey = hotkey };
164              auto it = hotkeyDescriptors.find(dummy);
165              if (it != hotkeyDescriptors.end())
166              {
167                  action = it->action;
168              }
169          }
170  
171          if (action)
172          {
173              if (action())
174              {
175                  // After invoking the hotkey send a dummy key to prevent Start Menu from activating
176                  INPUT dummyEvent[1] = {};
177                  dummyEvent[0].type = INPUT_KEYBOARD;
178                  dummyEvent[0].ki.wVk = 0xFF;
179                  dummyEvent[0].ki.dwFlags = KEYEVENTF_KEYUP;
180                  SendInput(1, dummyEvent, sizeof(INPUT));
181  
182                  // Swallow the key press
183                  return 1;
184              }
185          }
186  
187          return CallNextHookEx(hHook, nCode, wParam, lParam);
188      }
189  
190      void SetHotkeyAction(const std::wstring& moduleName, const Hotkey& hotkey, std::function<bool()>&& action) noexcept
191      {
192          Logger::trace(L"Register hotkey action for {}", moduleName);
193          std::unique_lock lock{ mutex };
194          hotkeyDescriptors.insert({ .hotkey = hotkey, .moduleName = moduleName, .action = std::move(action) });
195      }
196  
197      void AddPressedKeyAction(const std::wstring& moduleName, const DWORD vk, const UINT milliseconds, std::function<bool()>&& action) noexcept
198      {
199          // Calculate a unique TimerID.
200          auto hash = std::hash<std::wstring>{}(moduleName); // Hash the module as the upper part of the timer ID.
201          const UINT upperId = hash & 0xFFFF;
202          const UINT lowerId = vk & 0xFFFF; // The key to press can be the lower ID.
203          const UINT timerId = upperId << 16 | lowerId;
204          std::unique_lock lock{ pressedKeyMutex };
205          pressedKeyDescriptors.insert({ .virtualKey = vk, .moduleName = moduleName, .action = std::move(action), .idTimer = timerId, .millisecondsToPress = milliseconds });
206      }
207  
208      void ClearModuleHotkeys(const std::wstring& moduleName) noexcept
209      {
210          Logger::trace(L"UnRegister hotkey action for {}", moduleName);
211          {
212              std::unique_lock lock{ mutex };
213              auto it = hotkeyDescriptors.begin();
214              while (it != hotkeyDescriptors.end())
215              {
216                  if (it->moduleName == moduleName)
217                  {
218                      it = hotkeyDescriptors.erase(it);
219                  }
220                  else
221                  {
222                      ++it;
223                  }
224              }
225          }
226          {
227              std::unique_lock lock{ pressedKeyMutex };
228              auto it = pressedKeyDescriptors.begin();
229              while (it != pressedKeyDescriptors.end())
230              {
231                  if (it->moduleName == moduleName)
232                  {
233                      it = pressedKeyDescriptors.erase(it);
234                  }
235                  else
236                  {
237                      ++it;
238                  }
239              }
240          }
241      }
242  
243      void Start() noexcept
244      {
245  #if defined(DISABLE_LOWLEVEL_HOOKS_WHEN_DEBUGGED)
246          const bool hook_disabled = IsDebuggerPresent();
247  #else
248          const bool hook_disabled = false;
249  #endif
250          if (!hook_disabled)
251          {
252              if (!hHook)
253              {
254                  hHook = SetWindowsHookExW(WH_KEYBOARD_LL, KeyboardHookProc, NULL, NULL);
255                  if (!hHook)
256                  {
257                      DWORD errorCode = GetLastError();
258                      show_last_error_message(L"SetWindowsHookEx", errorCode, L"centralized_kb_hook");
259                  }
260              }
261          }
262      }
263  
264      void Stop() noexcept
265      {
266          if (hHook && UnhookWindowsHookEx(hHook))
267          {
268              hHook = NULL;
269          }
270      }
271  
272      void RegisterWindow(HWND hwnd) noexcept
273      {
274          runnerWindow = hwnd;
275      }
276  }