LoopbackCapture.cpp
1 #include "pch.h" 2 #include "LoopbackCapture.h" 3 #include <functiondiscoverykeys_devpkey.h> 4 5 #pragma comment(lib, "ole32.lib") 6 7 LoopbackCapture::LoopbackCapture() 8 { 9 m_stopEvent.create(wil::EventOptions::ManualReset); 10 m_samplesReadyEvent.create(wil::EventOptions::ManualReset); 11 } 12 13 LoopbackCapture::~LoopbackCapture() 14 { 15 Stop(); 16 if (m_pwfx) 17 { 18 CoTaskMemFree(m_pwfx); 19 m_pwfx = nullptr; 20 } 21 } 22 23 HRESULT LoopbackCapture::Initialize() 24 { 25 if (m_initialized.load()) 26 { 27 return S_OK; 28 } 29 30 HRESULT hr = CoCreateInstance( 31 __uuidof(MMDeviceEnumerator), 32 nullptr, 33 CLSCTX_ALL, 34 __uuidof(IMMDeviceEnumerator), 35 m_deviceEnumerator.put_void()); 36 if (FAILED(hr)) 37 { 38 return hr; 39 } 40 41 // Get the default audio render device (speakers/headphones) 42 hr = m_deviceEnumerator->GetDefaultAudioEndpoint(eRender, eConsole, m_device.put()); 43 if (FAILED(hr)) 44 { 45 return hr; 46 } 47 48 hr = m_device->Activate(__uuidof(IAudioClient), CLSCTX_ALL, nullptr, m_audioClient.put_void()); 49 if (FAILED(hr)) 50 { 51 return hr; 52 } 53 54 // Get the mix format 55 hr = m_audioClient->GetMixFormat(&m_pwfx); 56 if (FAILED(hr)) 57 { 58 return hr; 59 } 60 61 // Initialize audio client in loopback mode 62 // AUDCLNT_STREAMFLAGS_LOOPBACK enables capturing what's being played on the device 63 hr = m_audioClient->Initialize( 64 AUDCLNT_SHAREMODE_SHARED, 65 AUDCLNT_STREAMFLAGS_LOOPBACK, 66 1000000, // 100ms buffer to reduce capture latency 67 0, 68 m_pwfx, 69 nullptr); 70 if (FAILED(hr)) 71 { 72 return hr; 73 } 74 75 hr = m_audioClient->GetService(__uuidof(IAudioCaptureClient), m_captureClient.put_void()); 76 if (FAILED(hr)) 77 { 78 return hr; 79 } 80 81 m_initialized.store(true); 82 return S_OK; 83 } 84 85 HRESULT LoopbackCapture::Start() 86 { 87 if (!m_initialized.load()) 88 { 89 return E_NOT_VALID_STATE; 90 } 91 92 if (m_started.load()) 93 { 94 return S_OK; 95 } 96 97 m_stopEvent.ResetEvent(); 98 99 HRESULT hr = m_audioClient->Start(); 100 if (FAILED(hr)) 101 { 102 return hr; 103 } 104 105 m_started.store(true); 106 107 // Start capture thread 108 m_captureThread = std::thread(&LoopbackCapture::CaptureThread, this); 109 110 return S_OK; 111 } 112 113 void LoopbackCapture::Stop() 114 { 115 if (!m_started.load()) 116 { 117 return; 118 } 119 120 m_stopEvent.SetEvent(); 121 122 if (m_captureThread.joinable()) 123 { 124 m_captureThread.join(); 125 } 126 127 DrainCaptureClient(); 128 129 if (m_audioClient) 130 { 131 m_audioClient->Stop(); 132 } 133 134 m_started.store(false); 135 } 136 137 void LoopbackCapture::DrainCaptureClient() 138 { 139 if (!m_captureClient) 140 { 141 return; 142 } 143 144 while (true) 145 { 146 UINT32 packetLength = 0; 147 HRESULT hr = m_captureClient->GetNextPacketSize(&packetLength); 148 if (FAILED(hr) || packetLength == 0) 149 { 150 break; 151 } 152 153 BYTE* pData = nullptr; 154 UINT32 numFramesAvailable = 0; 155 DWORD flags = 0; 156 hr = m_captureClient->GetBuffer(&pData, &numFramesAvailable, &flags, nullptr, nullptr); 157 if (FAILED(hr)) 158 { 159 break; 160 } 161 162 if (numFramesAvailable > 0) 163 { 164 std::vector<float> samples; 165 166 if (m_pwfx->wFormatTag == WAVE_FORMAT_IEEE_FLOAT || 167 (m_pwfx->wFormatTag == WAVE_FORMAT_EXTENSIBLE && 168 reinterpret_cast<WAVEFORMATEXTENSIBLE*>(m_pwfx)->SubFormat == KSDATAFORMAT_SUBTYPE_IEEE_FLOAT)) 169 { 170 if (flags & AUDCLNT_BUFFERFLAGS_SILENT) 171 { 172 samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels, 0.0f); 173 } 174 else 175 { 176 float* floatData = reinterpret_cast<float*>(pData); 177 samples.assign(floatData, floatData + (static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels)); 178 } 179 } 180 else if (m_pwfx->wFormatTag == WAVE_FORMAT_PCM || 181 (m_pwfx->wFormatTag == WAVE_FORMAT_EXTENSIBLE && 182 reinterpret_cast<WAVEFORMATEXTENSIBLE*>(m_pwfx)->SubFormat == KSDATAFORMAT_SUBTYPE_PCM)) 183 { 184 if (flags & AUDCLNT_BUFFERFLAGS_SILENT) 185 { 186 samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels, 0.0f); 187 } 188 else if (m_pwfx->wBitsPerSample == 16) 189 { 190 int16_t* pcmData = reinterpret_cast<int16_t*>(pData); 191 samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels); 192 for (size_t i = 0; i < samples.size(); i++) 193 { 194 samples[i] = static_cast<float>(pcmData[i]) / 32768.0f; 195 } 196 } 197 else if (m_pwfx->wBitsPerSample == 32) 198 { 199 int32_t* pcmData = reinterpret_cast<int32_t*>(pData); 200 samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels); 201 for (size_t i = 0; i < samples.size(); i++) 202 { 203 samples[i] = static_cast<float>(pcmData[i]) / 2147483648.0f; 204 } 205 } 206 } 207 208 if (!samples.empty()) 209 { 210 auto lock = m_lock.lock_exclusive(); 211 m_sampleQueue.push_back(std::move(samples)); 212 m_samplesReadyEvent.SetEvent(); 213 } 214 } 215 216 hr = m_captureClient->ReleaseBuffer(numFramesAvailable); 217 if (FAILED(hr)) 218 { 219 break; 220 } 221 } 222 } 223 224 void LoopbackCapture::CaptureThread() 225 { 226 while (WaitForSingleObject(m_stopEvent.get(), 10) == WAIT_TIMEOUT) 227 { 228 UINT32 packetLength = 0; 229 HRESULT hr = m_captureClient->GetNextPacketSize(&packetLength); 230 if (FAILED(hr)) 231 { 232 break; 233 } 234 235 while (packetLength != 0) 236 { 237 BYTE* pData = nullptr; 238 UINT32 numFramesAvailable = 0; 239 DWORD flags = 0; 240 241 hr = m_captureClient->GetBuffer(&pData, &numFramesAvailable, &flags, nullptr, nullptr); 242 if (FAILED(hr)) 243 { 244 break; 245 } 246 247 if (numFramesAvailable > 0) 248 { 249 std::vector<float> samples; 250 251 // Convert to float samples 252 if (m_pwfx->wFormatTag == WAVE_FORMAT_IEEE_FLOAT || 253 (m_pwfx->wFormatTag == WAVE_FORMAT_EXTENSIBLE && 254 reinterpret_cast<WAVEFORMATEXTENSIBLE*>(m_pwfx)->SubFormat == KSDATAFORMAT_SUBTYPE_IEEE_FLOAT)) 255 { 256 // Already float format 257 if (flags & AUDCLNT_BUFFERFLAGS_SILENT) 258 { 259 // Insert silence 260 samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels, 0.0f); 261 } 262 else 263 { 264 float* floatData = reinterpret_cast<float*>(pData); 265 samples.assign(floatData, floatData + (static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels)); 266 } 267 } 268 else if (m_pwfx->wFormatTag == WAVE_FORMAT_PCM || 269 (m_pwfx->wFormatTag == WAVE_FORMAT_EXTENSIBLE && 270 reinterpret_cast<WAVEFORMATEXTENSIBLE*>(m_pwfx)->SubFormat == KSDATAFORMAT_SUBTYPE_PCM)) 271 { 272 // Convert PCM to float 273 if (flags & AUDCLNT_BUFFERFLAGS_SILENT) 274 { 275 samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels, 0.0f); 276 } 277 else if (m_pwfx->wBitsPerSample == 16) 278 { 279 int16_t* pcmData = reinterpret_cast<int16_t*>(pData); 280 samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels); 281 for (size_t i = 0; i < samples.size(); i++) 282 { 283 samples[i] = static_cast<float>(pcmData[i]) / 32768.0f; 284 } 285 } 286 else if (m_pwfx->wBitsPerSample == 32) 287 { 288 int32_t* pcmData = reinterpret_cast<int32_t*>(pData); 289 samples.resize(static_cast<size_t>(numFramesAvailable) * m_pwfx->nChannels); 290 for (size_t i = 0; i < samples.size(); i++) 291 { 292 samples[i] = static_cast<float>(pcmData[i]) / 2147483648.0f; 293 } 294 } 295 } 296 297 if (!samples.empty()) 298 { 299 auto lock = m_lock.lock_exclusive(); 300 m_sampleQueue.push_back(std::move(samples)); 301 m_samplesReadyEvent.SetEvent(); 302 } 303 } 304 305 hr = m_captureClient->ReleaseBuffer(numFramesAvailable); 306 if (FAILED(hr)) 307 { 308 break; 309 } 310 311 hr = m_captureClient->GetNextPacketSize(&packetLength); 312 if (FAILED(hr)) 313 { 314 break; 315 } 316 } 317 } 318 } 319 320 bool LoopbackCapture::TryGetSamples(std::vector<float>& samples) 321 { 322 auto lock = m_lock.lock_exclusive(); 323 if (m_sampleQueue.empty()) 324 { 325 return false; 326 } 327 328 samples = std::move(m_sampleQueue.front()); 329 m_sampleQueue.pop_front(); 330 331 if (m_sampleQueue.empty()) 332 { 333 m_samplesReadyEvent.ResetEvent(); 334 } 335 336 return true; 337 }