/ src / modules / ZoomIt / ZoomIt / LoopbackCapture.cpp
LoopbackCapture.cpp
  1  #include "pch.h"
  2  #include "LoopbackCapture.h"
  3  #include <functiondiscoverykeys_devpkey.h>
  4  
  5  #pragma comment(lib, "ole32.lib")
  6  
  7  LoopbackCapture::LoopbackCapture()
  8  {
  9      m_stopEvent.create(wil::EventOptions::ManualReset);
 10      m_samplesReadyEvent.create(wil::EventOptions::ManualReset);
 11  }
 12  
 13  LoopbackCapture::~LoopbackCapture()
 14  {
 15      Stop();
 16      if (m_pwfx)
 17      {
 18          CoTaskMemFree(m_pwfx);
 19          m_pwfx = nullptr;
 20      }
 21  }
 22  
 23  HRESULT LoopbackCapture::Initialize()
 24  {
 25      if (m_initialized.load())
 26      {
 27          return S_OK;
 28      }
 29  
 30      HRESULT hr = CoCreateInstance(
 31          __uuidof(MMDeviceEnumerator),
 32          nullptr,
 33          CLSCTX_ALL,
 34          __uuidof(IMMDeviceEnumerator),
 35          m_deviceEnumerator.put_void());
 36      if (FAILED(hr))
 37      {
 38          return hr;
 39      }
 40  
 41      // Get the default audio render device (speakers/headphones)
 42      hr = m_deviceEnumerator->GetDefaultAudioEndpoint(eRender, eConsole, m_device.put());
 43      if (FAILED(hr))
 44      {
 45          return hr;
 46      }
 47  
 48      hr = m_device->Activate(__uuidof(IAudioClient), CLSCTX_ALL, nullptr, m_audioClient.put_void());
 49      if (FAILED(hr))
 50      {
 51          return hr;
 52      }
 53  
 54      // Get the mix format
 55      hr = m_audioClient->GetMixFormat(&m_pwfx);
 56      if (FAILED(hr))
 57      {
 58          return hr;
 59      }
 60  
 61      // Initialize audio client in loopback mode
 62      // AUDCLNT_STREAMFLAGS_LOOPBACK enables capturing what's being played on the device
 63      hr = m_audioClient->Initialize(
 64          AUDCLNT_SHAREMODE_SHARED,
 65          AUDCLNT_STREAMFLAGS_LOOPBACK,
 66          1000000, // 100ms buffer to reduce capture latency
 67          0,
 68          m_pwfx,
 69          nullptr);
 70      if (FAILED(hr))
 71      {
 72          return hr;
 73      }
 74  
 75      hr = m_audioClient->GetService(__uuidof(IAudioCaptureClient), m_captureClient.put_void());
 76      if (FAILED(hr))
 77      {
 78          return hr;
 79      }
 80  
 81      m_initialized.store(true);
 82      return S_OK;
 83  }
 84  
 85  HRESULT LoopbackCapture::Start()
 86  {
 87      if (!m_initialized.load())
 88      {
 89          return E_NOT_VALID_STATE;
 90      }
 91  
 92      if (m_started.load())
 93      {
 94          return S_OK;
 95      }
 96  
 97      m_stopEvent.ResetEvent();
 98  
 99      HRESULT hr = m_audioClient->Start();
100      if (FAILED(hr))
101      {
102          return hr;
103      }
104  
105      m_started.store(true);
106  
107      // Start capture thread
108      m_captureThread = std::thread(&LoopbackCapture::CaptureThread, this);
109  
110      return S_OK;
111  }
112  
113  void LoopbackCapture::Stop()
114  {
115      if (!m_started.load())
116      {
117          return;
118      }
119  
120      m_stopEvent.SetEvent();
121  
122      if (m_captureThread.joinable())
123      {
124          m_captureThread.join();
125      }
126  
127      DrainCaptureClient();
128  
129      if (m_audioClient)
130      {
131          m_audioClient->Stop();
132      }
133  
134      m_started.store(false);
135  }
136  
137  void LoopbackCapture::DrainCaptureClient()
138  {
139      if (!m_captureClient)
140      {
141          return;
142      }
143  
144      while (true)
145      {
146          UINT32 packetLength = 0;
147          HRESULT hr = m_captureClient->GetNextPacketSize(&packetLength);
148          if (FAILED(hr) || packetLength == 0)
149          {
150              break;
151          }
152  
153          BYTE* pData = nullptr;
154          UINT32 numFramesAvailable = 0;
155          DWORD flags = 0;
156          hr = m_captureClient->GetBuffer(&pData, &numFramesAvailable, &flags, nullptr, nullptr);
157          if (FAILED(hr))
158          {
159              break;
160          }
161  
162          if (numFramesAvailable > 0)
163          {
164              std::vector<float> samples;
165  
166              if (m_pwfx->wFormatTag == WAVE_FORMAT_IEEE_FLOAT ||
167                  (m_pwfx->wFormatTag == WAVE_FORMAT_EXTENSIBLE &&
168                   reinterpret_cast<WAVEFORMATEXTENSIBLE*>(m_pwfx)->SubFormat == KSDATAFORMAT_SUBTYPE_IEEE_FLOAT))
169              {
170                  if (flags & AUDCLNT_BUFFERFLAGS_SILENT)
171                  {
172                      samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels, 0.0f);
173                  }
174                  else
175                  {
176                      float* floatData = reinterpret_cast<float*>(pData);
177                      samples.assign(floatData, floatData + (static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels));
178                  }
179              }
180              else if (m_pwfx->wFormatTag == WAVE_FORMAT_PCM ||
181                       (m_pwfx->wFormatTag == WAVE_FORMAT_EXTENSIBLE &&
182                        reinterpret_cast<WAVEFORMATEXTENSIBLE*>(m_pwfx)->SubFormat == KSDATAFORMAT_SUBTYPE_PCM))
183              {
184                  if (flags & AUDCLNT_BUFFERFLAGS_SILENT)
185                  {
186                      samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels, 0.0f);
187                  }
188                  else if (m_pwfx->wBitsPerSample == 16)
189                  {
190                      int16_t* pcmData = reinterpret_cast<int16_t*>(pData);
191                      samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels);
192                      for (size_t i = 0; i < samples.size(); i++)
193                      {
194                          samples[i] = static_cast<float>(pcmData[i]) / 32768.0f;
195                      }
196                  }
197                  else if (m_pwfx->wBitsPerSample == 32)
198                  {
199                      int32_t* pcmData = reinterpret_cast<int32_t*>(pData);
200                      samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels);
201                      for (size_t i = 0; i < samples.size(); i++)
202                      {
203                          samples[i] = static_cast<float>(pcmData[i]) / 2147483648.0f;
204                      }
205                  }
206              }
207  
208              if (!samples.empty())
209              {
210                  auto lock = m_lock.lock_exclusive();
211                  m_sampleQueue.push_back(std::move(samples));
212                  m_samplesReadyEvent.SetEvent();
213              }
214          }
215  
216          hr = m_captureClient->ReleaseBuffer(numFramesAvailable);
217          if (FAILED(hr))
218          {
219              break;
220          }
221      }
222  }
223  
224  void LoopbackCapture::CaptureThread()
225  {
226      while (WaitForSingleObject(m_stopEvent.get(), 10) == WAIT_TIMEOUT)
227      {
228          UINT32 packetLength = 0;
229          HRESULT hr = m_captureClient->GetNextPacketSize(&packetLength);
230          if (FAILED(hr))
231          {
232              break;
233          }
234  
235          while (packetLength != 0)
236          {
237              BYTE* pData = nullptr;
238              UINT32 numFramesAvailable = 0;
239              DWORD flags = 0;
240  
241              hr = m_captureClient->GetBuffer(&pData, &numFramesAvailable, &flags, nullptr, nullptr);
242              if (FAILED(hr))
243              {
244                  break;
245              }
246  
247              if (numFramesAvailable > 0)
248              {
249                  std::vector<float> samples;
250  
251                  // Convert to float samples
252                  if (m_pwfx->wFormatTag == WAVE_FORMAT_IEEE_FLOAT ||
253                      (m_pwfx->wFormatTag == WAVE_FORMAT_EXTENSIBLE &&
254                       reinterpret_cast<WAVEFORMATEXTENSIBLE*>(m_pwfx)->SubFormat == KSDATAFORMAT_SUBTYPE_IEEE_FLOAT))
255                  {
256                      // Already float format
257                      if (flags & AUDCLNT_BUFFERFLAGS_SILENT)
258                      {
259                          // Insert silence
260                          samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels, 0.0f);
261                      }
262                      else
263                      {
264                          float* floatData = reinterpret_cast<float*>(pData);
265                          samples.assign(floatData, floatData + (static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels));
266                      }
267                  }
268                  else if (m_pwfx->wFormatTag == WAVE_FORMAT_PCM ||
269                           (m_pwfx->wFormatTag == WAVE_FORMAT_EXTENSIBLE &&
270                            reinterpret_cast<WAVEFORMATEXTENSIBLE*>(m_pwfx)->SubFormat == KSDATAFORMAT_SUBTYPE_PCM))
271                  {
272                      // Convert PCM to float
273                      if (flags & AUDCLNT_BUFFERFLAGS_SILENT)
274                      {
275                          samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels, 0.0f);
276                      }
277                      else if (m_pwfx->wBitsPerSample == 16)
278                      {
279                          int16_t* pcmData = reinterpret_cast<int16_t*>(pData);
280                          samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels);
281                          for (size_t i = 0; i < samples.size(); i++)
282                          {
283                              samples[i] = static_cast<float>(pcmData[i]) / 32768.0f;
284                          }
285                      }
286                      else if (m_pwfx->wBitsPerSample == 32)
287                      {
288                          int32_t* pcmData = reinterpret_cast<int32_t*>(pData);
289                          samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels);
290                          for (size_t i = 0; i < samples.size(); i++)
291                          {
292                              samples[i] = static_cast<float>(pcmData[i]) / 2147483648.0f;
293                          }
294                      }
295                  }
296  
297                  if (!samples.empty())
298                  {
299                      auto lock = m_lock.lock_exclusive();
300                      m_sampleQueue.push_back(std::move(samples));
301                      m_samplesReadyEvent.SetEvent();
302                  }
303              }
304  
305              hr = m_captureClient->ReleaseBuffer(numFramesAvailable);
306              if (FAILED(hr))
307              {
308                  break;
309              }
310  
311              hr = m_captureClient->GetNextPacketSize(&packetLength);
312              if (FAILED(hr))
313              {
314                  break;
315              }
316          }
317      }
318  }
319  
320  bool LoopbackCapture::TryGetSamples(std::vector<float>& samples)
321  {
322      auto lock = m_lock.lock_exclusive();
323      if (m_sampleQueue.empty())
324      {
325          return false;
326      }
327  
328      samples = std::move(m_sampleQueue.front());
329      m_sampleQueue.pop_front();
330  
331      if (m_sampleQueue.empty())
332      {
333          m_samplesReadyEvent.ResetEvent();
334      }
335  
336      return true;
337  }