/ src / common / thread_worker.h
thread_worker.h
  1  // Copyright 2020 yuzu Emulator Project
  2  // Licensed under GPLv2 or any later version
  3  // Refer to the license.txt file included.
  4  
  5  #pragma once
  6  
  7  #include <atomic>
  8  #include <condition_variable>
  9  #include <functional>
 10  #include <mutex>
 11  #include <string>
 12  #include <thread>
 13  #include <type_traits>
 14  #include <vector>
 15  #include <queue>
 16  
 17  #include "common/polyfill_thread.h"
 18  #include "common/thread.h"
 19  #include "common/unique_function.h"
 20  
 21  namespace Common {
 22  
 23  template <class StateType = void>
 24  class StatefulThreadWorker {
 25      static constexpr bool with_state = !std::is_same_v<StateType, void>;
 26  
 27      struct DummyCallable {
 28          int operator()(std::size_t) const noexcept {
 29              return 0;
 30          }
 31      };
 32  
 33      using Task =
 34          std::conditional_t<with_state, UniqueFunction<void, StateType*>, UniqueFunction<void>>;
 35      using StateMaker =
 36          std::conditional_t<with_state, std::function<StateType(std::size_t)>, DummyCallable>;
 37  
 38  public:
 39      explicit StatefulThreadWorker(std::size_t num_workers, std::string_view name,
 40                                    StateMaker func = {})
 41          : workers_queued{num_workers}, thread_name{name} {
 42          const auto lambda = [this, func](std::stop_token stop_token, std::size_t index) {
 43              Common::SetCurrentThreadName(thread_name.data());
 44              {
 45                  [[maybe_unused]] std::conditional_t<with_state, StateType, int> state{func(index)};
 46                  while (!stop_token.stop_requested()) {
 47                      Task task;
 48                      {
 49                          std::unique_lock lock{queue_mutex};
 50                          if (requests.empty()) {
 51                              wait_condition.notify_all();
 52                          }
 53                          Common::CondvarWait(condition, lock, stop_token,
 54                                              [this] { return !requests.empty(); });
 55                          if (stop_token.stop_requested()) {
 56                              break;
 57                          }
 58                          task = std::move(requests.front());
 59                          requests.pop();
 60                      }
 61                      if constexpr (with_state) {
 62                          task(&state);
 63                      } else {
 64                          task();
 65                      }
 66                      ++work_done;
 67                  }
 68              }
 69              ++workers_stopped;
 70              wait_condition.notify_all();
 71          };
 72          threads.reserve(num_workers);
 73          for (std::size_t i = 0; i < num_workers; ++i) {
 74              threads.emplace_back(lambda, i);
 75          }
 76      }
 77  
 78      StatefulThreadWorker& operator=(const StatefulThreadWorker&) = delete;
 79      StatefulThreadWorker(const StatefulThreadWorker&) = delete;
 80  
 81      StatefulThreadWorker& operator=(StatefulThreadWorker&&) = delete;
 82      StatefulThreadWorker(StatefulThreadWorker&&) = delete;
 83  
 84      void QueueWork(Task work) {
 85          {
 86              std::unique_lock lock{queue_mutex};
 87              requests.emplace(std::move(work));
 88              ++work_scheduled;
 89          }
 90          condition.notify_one();
 91      }
 92  
 93      void WaitForRequests(std::stop_token stop_token = {}) {
 94          std::stop_callback callback(stop_token, [this] {
 95              for (auto& thread : threads) {
 96                  thread.request_stop();
 97              }
 98          });
 99          std::unique_lock lock{queue_mutex};
100          wait_condition.wait(lock, [this] {
101              return workers_stopped >= workers_queued || work_done >= work_scheduled;
102          });
103      }
104  
105      const std::size_t NumWorkers() const noexcept {
106          return threads.size();
107      }
108  
109  private:
110      std::queue<Task> requests;
111      std::mutex queue_mutex;
112      std::condition_variable_any condition;
113      std::condition_variable wait_condition;
114      std::atomic<std::size_t> work_scheduled{};
115      std::atomic<std::size_t> work_done{};
116      std::atomic<std::size_t> workers_stopped{};
117      std::atomic<std::size_t> workers_queued{};
118      std::string_view thread_name;
119      std::vector<std::jthread> threads;
120  };
121  
122  using ThreadWorker = StatefulThreadWorker<>;
123  
124  } // namespace Common