CLILogic.cpp
  1  #include "pch.h"
  2  #include "CLILogic.h"
  3  #include <common/utils/json.h>
  4  #include <iostream>
  5  #include <sstream>
  6  #include <chrono>
  7  #include "resource.h"
  8  #include <common/logger/logger.h>
  9  #include <common/utils/logger_helper.h>
 10  #include <type_traits>
 11  
 12  template<typename T>
 13  DWORD_PTR ToDwordPtr(T val)
 14  {
 15      if constexpr (std::is_pointer_v<T>)
 16      {
 17          return reinterpret_cast<DWORD_PTR>(val);
 18      }
 19      else
 20      {
 21          return static_cast<DWORD_PTR>(val);
 22      }
 23  }
 24  
 25  template<typename... Args>
 26  std::wstring FormatString(IStringProvider& strings, UINT id, Args... args)
 27  {
 28      std::wstring format = strings.GetString(id);
 29      if (format.empty()) return L"";
 30  
 31      DWORD_PTR arguments[] = { ToDwordPtr(args)..., 0 };
 32  
 33      LPWSTR buffer = nullptr;
 34      FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_STRING | FORMAT_MESSAGE_ARGUMENT_ARRAY,
 35                     format.c_str(),
 36                     0,
 37                     0,
 38                     reinterpret_cast<LPWSTR>(&buffer),
 39                     0,
 40                     reinterpret_cast<va_list*>(arguments));
 41  
 42      if (buffer)
 43      {
 44          std::wstring result(buffer);
 45          LocalFree(buffer);
 46          return result;
 47      }
 48      return L"";
 49  }
 50  
 51  std::wstring get_usage(IStringProvider& strings)
 52  {
 53      return strings.GetString(IDS_USAGE);
 54  }
 55  
 56  std::wstring get_json(const std::vector<ProcessResult>& results)
 57  {
 58      json::JsonObject root;
 59      json::JsonArray processes;
 60  
 61      for (const auto& result : results)
 62      {
 63          json::JsonObject process;
 64          process.SetNamedValue(L"pid", json::JsonValue::CreateNumberValue(result.pid));
 65          process.SetNamedValue(L"name", json::JsonValue::CreateStringValue(result.name));
 66          process.SetNamedValue(L"user", json::JsonValue::CreateStringValue(result.user));
 67          
 68          json::JsonArray files;
 69          for (const auto& file : result.files)
 70          {
 71              files.Append(json::JsonValue::CreateStringValue(file));
 72          }
 73          process.SetNamedValue(L"files", files);
 74          
 75          processes.Append(process);
 76      }
 77  
 78      root.SetNamedValue(L"processes", processes);
 79      return root.Stringify().c_str();
 80  }
 81  
 82  std::wstring get_text(const std::vector<ProcessResult>& results, IStringProvider& strings)
 83  {
 84      std::wstringstream ss;
 85      if (results.empty())
 86      {
 87          ss << strings.GetString(IDS_NO_PROCESSES);
 88          return ss.str();
 89      }
 90  
 91      ss << strings.GetString(IDS_HEADER);
 92      for (const auto& result : results)
 93      {
 94          ss << result.pid << L"\t" 
 95             << result.user << L"\t" 
 96             << result.name << std::endl;
 97      }
 98      return ss.str();
 99  }
100  
101  std::wstring kill_processes(const std::vector<ProcessResult>& results, IProcessTerminator& terminator, IStringProvider& strings)
102  {
103      std::wstringstream ss;
104      for (const auto& result : results)
105      {
106          if (terminator.terminate(result.pid))
107          {
108              ss << FormatString(strings, IDS_TERMINATED, result.pid, result.name.c_str());
109          }
110          else
111          {
112              ss << FormatString(strings, IDS_FAILED_TERMINATE, result.pid, result.name.c_str());
113          }
114      }
115      return ss.str();
116  }
117  
118  CommandResult run_command(int argc, wchar_t* argv[], IProcessFinder& finder, IProcessTerminator& terminator, IStringProvider& strings)
119  {
120      Logger::info("Parsing arguments");
121      if (argc < 2)
122      {
123          Logger::warn("No arguments provided");
124          return { 1, get_usage(strings) };
125      }
126  
127      bool json_output = false;
128      bool kill = false;
129      bool wait = false;
130      int timeout_ms = -1;
131      std::vector<std::wstring> paths;
132  
133      for (int i = 1; i < argc; ++i)
134      {
135          std::wstring arg = argv[i];
136          if (arg == L"--json")
137          {
138              json_output = true;
139          }
140          else if (arg == L"--kill")
141          {
142              kill = true;
143          }
144          else if (arg == L"--wait")
145          {
146              wait = true;
147          }
148          else if (arg == L"--timeout")
149          {
150              if (i + 1 < argc)
151              {
152                  try
153                  {
154                      timeout_ms = std::stoi(argv[++i]);
155                  }
156                  catch (...)
157                  {
158                      Logger::error("Invalid timeout value");
159                      return { 1, strings.GetString(IDS_ERROR_INVALID_TIMEOUT) };
160                  }
161              }
162              else
163              {
164                  Logger::error("Timeout argument missing");
165                  return { 1, strings.GetString(IDS_ERROR_TIMEOUT_ARG) };
166              }
167          }
168          else if (arg == L"--help")
169          {
170              return { 0, get_usage(strings) };
171          }
172          else
173          {
174              paths.push_back(arg);
175          }
176      }
177  
178      if (paths.empty())
179      {
180          Logger::error("No paths specified");
181          return { 1, strings.GetString(IDS_ERROR_NO_PATHS) };
182      }
183  
184      Logger::info("Processing {} paths", paths.size());
185  
186      if (wait)
187      {
188          std::wstringstream ss;
189          if (json_output)
190          {
191               Logger::warn("Wait is incompatible with JSON output");
192               ss << strings.GetString(IDS_WARN_JSON_WAIT);
193               json_output = false;
194          }
195          
196          ss << strings.GetString(IDS_WAITING);
197          auto start_time = std::chrono::steady_clock::now();
198          while (true)
199          {
200              auto results = finder.find(paths);
201              if (results.empty())
202              {
203                  Logger::info("Files unlocked");
204                  ss << strings.GetString(IDS_UNLOCKED);
205                  break;
206              }
207  
208              if (timeout_ms >= 0)
209              {
210                  auto current_time = std::chrono::steady_clock::now();
211                  auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(current_time - start_time).count();
212                  if (elapsed > timeout_ms)
213                  {
214                      Logger::warn("Timeout waiting for files to be unlocked");
215                      ss << strings.GetString(IDS_TIMEOUT);
216                      return { 1, ss.str() };
217                  }
218              }
219  
220              Sleep(200);
221          }
222          return { 0, ss.str() };
223      }
224  
225      auto results = finder.find(paths);
226      Logger::info("Found {} processes locking the files", results.size());
227      std::wstringstream output_ss;
228  
229      if (kill)
230      {
231          Logger::info("Killing processes");
232          output_ss << kill_processes(results, terminator, strings);
233          // Re-check after killing
234          results = finder.find(paths);
235          Logger::info("Remaining processes: {}", results.size());
236      }
237  
238      if (json_output)
239      {
240          output_ss << get_json(results) << std::endl;
241      }
242      else
243      {
244          output_ss << get_text(results, strings);
245      }
246  
247      return { 0, output_ss.str() };
248  }