/ src / discord_rpc.cpp
discord_rpc.cpp
  1  #include "discord_rpc.h"
  2  
  3  #include "backoff.h"
  4  #include "discord_register.h"
  5  #include "msg_queue.h"
  6  #include "rpc_connection.h"
  7  #include "serialization.h"
  8  
  9  #include <atomic>
 10  #include <chrono>
 11  #include <mutex>
 12  
 13  #ifndef DISCORD_DISABLE_IO_THREAD
 14  #include <condition_variable>
 15  #include <thread>
 16  #endif
 17  
 18  constexpr size_t MaxMessageSize{16 * 1024};
 19  constexpr size_t MessageQueueSize{8};
 20  constexpr size_t JoinQueueSize{8};
 21  
 22  struct QueuedMessage {
 23      size_t length;
 24      char buffer[MaxMessageSize];
 25  
 26      void Copy(const QueuedMessage& other)
 27      {
 28          length = other.length;
 29          if (length) {
 30              memcpy(buffer, other.buffer, length);
 31          }
 32      }
 33  };
 34  
 35  struct User {
 36      // snowflake (64bit int), turned into a ascii decimal string, at most 20 chars +1 null
 37      // terminator = 21
 38      char userId[32];
 39      // 32 unicode glyphs is max name size => 4 bytes per glyph in the worst case, +1 for null
 40      // terminator = 129
 41      char username[344];
 42      // 4 decimal digits + 1 null terminator = 5
 43      char discriminator[8];
 44      // optional 'a_' + md5 hex digest (32 bytes) + null terminator = 35
 45      char avatar[128];
 46      // Rounded way up because I'm paranoid about games breaking from future changes in these sizes
 47  };
 48  
 49  static RpcConnection* Connection{nullptr};
 50  static DiscordEventHandlers QueuedHandlers{};
 51  static DiscordEventHandlers Handlers{};
 52  static std::atomic_bool WasJustConnected{false};
 53  static std::atomic_bool WasJustDisconnected{false};
 54  static std::atomic_bool GotErrorMessage{false};
 55  static std::atomic_bool WasJoinGame{false};
 56  static std::atomic_bool WasSpectateGame{false};
 57  static std::atomic_bool UpdatePresence{false};
 58  static char JoinGameSecret[256];
 59  static char SpectateGameSecret[256];
 60  static int LastErrorCode{0};
 61  static char LastErrorMessage[256];
 62  static int LastDisconnectErrorCode{0};
 63  static char LastDisconnectErrorMessage[256];
 64  static std::mutex PresenceMutex;
 65  static std::mutex HandlerMutex;
 66  static QueuedMessage QueuedPresence{};
 67  static MsgQueue<QueuedMessage, MessageQueueSize> SendQueue;
 68  static MsgQueue<User, JoinQueueSize> JoinAskQueue;
 69  static User connectedUser;
 70  
 71  // We want to auto connect, and retry on failure, but not as fast as possible. This does expoential
 72  // backoff from 0.5 seconds to 1 minute
 73  static Backoff ReconnectTimeMs(500, 60 * 1000);
 74  static auto NextConnect = std::chrono::system_clock::now();
 75  static int Pid{0};
 76  static int Nonce{1};
 77  
 78  #ifndef DISCORD_DISABLE_IO_THREAD
 79  static void Discord_UpdateConnection(void);
 80  class IoThreadHolder {
 81  private:
 82      std::atomic_bool keepRunning{true};
 83      std::mutex waitForIOMutex;
 84      std::condition_variable waitForIOActivity;
 85      std::thread ioThread;
 86  
 87  public:
 88      void Start()
 89      {
 90          keepRunning.store(true);
 91          ioThread = std::thread([&]() {
 92              const std::chrono::duration<int64_t, std::milli> maxWait{500LL};
 93              Discord_UpdateConnection();
 94              while (keepRunning.load()) {
 95                  std::unique_lock<std::mutex> lock(waitForIOMutex);
 96                  waitForIOActivity.wait_for(lock, maxWait);
 97                  Discord_UpdateConnection();
 98              }
 99          });
100      }
101  
102      void Notify() { waitForIOActivity.notify_all(); }
103  
104      void Stop()
105      {
106          keepRunning.exchange(false);
107          Notify();
108          if (ioThread.joinable()) {
109              ioThread.join();
110          }
111      }
112  
113      ~IoThreadHolder() { Stop(); }
114  };
115  #else
116  class IoThreadHolder {
117  public:
118      void Start() {}
119      void Stop() {}
120      void Notify() {}
121  };
122  #endif // DISCORD_DISABLE_IO_THREAD
123  static IoThreadHolder* IoThread{nullptr};
124  
125  static void UpdateReconnectTime()
126  {
127      NextConnect = std::chrono::system_clock::now() +
128        std::chrono::duration<int64_t, std::milli>{ReconnectTimeMs.nextDelay()};
129  }
130  
131  #ifdef DISCORD_DISABLE_IO_THREAD
132  extern "C" DISCORD_EXPORT void Discord_UpdateConnection(void)
133  #else
134  static void Discord_UpdateConnection(void)
135  #endif
136  {
137      if (!Connection) {
138          return;
139      }
140  
141      if (!Connection->IsOpen()) {
142          if (std::chrono::system_clock::now() >= NextConnect) {
143              UpdateReconnectTime();
144              Connection->Open();
145          }
146      }
147      else {
148          // reads
149  
150          for (;;) {
151              JsonDocument message;
152  
153              if (!Connection->Read(message)) {
154                  break;
155              }
156  
157              const char* evtName = GetStrMember(&message, "evt");
158              const char* nonce = GetStrMember(&message, "nonce");
159  
160              if (nonce) {
161                  // in responses only -- should use to match up response when needed.
162  
163                  if (evtName && strcmp(evtName, "ERROR") == 0) {
164                      auto data = GetObjMember(&message, "data");
165                      LastErrorCode = GetIntMember(data, "code");
166                      StringCopy(LastErrorMessage, GetStrMember(data, "message", ""));
167                      GotErrorMessage.store(true);
168                  }
169              }
170              else {
171                  // should have evt == name of event, optional data
172                  if (evtName == nullptr) {
173                      continue;
174                  }
175  
176                  auto data = GetObjMember(&message, "data");
177  
178                  if (strcmp(evtName, "ACTIVITY_JOIN") == 0) {
179                      auto secret = GetStrMember(data, "secret");
180                      if (secret) {
181                          StringCopy(JoinGameSecret, secret);
182                          WasJoinGame.store(true);
183                      }
184                  }
185                  else if (strcmp(evtName, "ACTIVITY_SPECTATE") == 0) {
186                      auto secret = GetStrMember(data, "secret");
187                      if (secret) {
188                          StringCopy(SpectateGameSecret, secret);
189                          WasSpectateGame.store(true);
190                      }
191                  }
192                  else if (strcmp(evtName, "ACTIVITY_JOIN_REQUEST") == 0) {
193                      auto user = GetObjMember(data, "user");
194                      auto userId = GetStrMember(user, "id");
195                      auto username = GetStrMember(user, "username");
196                      auto avatar = GetStrMember(user, "avatar");
197                      auto joinReq = JoinAskQueue.GetNextAddMessage();
198                      if (userId && username && joinReq) {
199                          StringCopy(joinReq->userId, userId);
200                          StringCopy(joinReq->username, username);
201                          auto discriminator = GetStrMember(user, "discriminator");
202                          if (discriminator) {
203                              StringCopy(joinReq->discriminator, discriminator);
204                          }
205                          if (avatar) {
206                              StringCopy(joinReq->avatar, avatar);
207                          }
208                          else {
209                              joinReq->avatar[0] = 0;
210                          }
211                          JoinAskQueue.CommitAdd();
212                      }
213                  }
214              }
215          }
216  
217          // writes
218          if (UpdatePresence.exchange(false) && QueuedPresence.length) {
219              QueuedMessage local;
220              {
221                  std::lock_guard<std::mutex> guard(PresenceMutex);
222                  local.Copy(QueuedPresence);
223              }
224              if (!Connection->Write(local.buffer, local.length)) {
225                  // if we fail to send, requeue
226                  std::lock_guard<std::mutex> guard(PresenceMutex);
227                  QueuedPresence.Copy(local);
228                  UpdatePresence.exchange(true);
229              }
230          }
231  
232          while (SendQueue.HavePendingSends()) {
233              auto qmessage = SendQueue.GetNextSendMessage();
234              Connection->Write(qmessage->buffer, qmessage->length);
235              SendQueue.CommitSend();
236          }
237      }
238  }
239  
240  static void SignalIOActivity()
241  {
242      if (IoThread != nullptr) {
243          IoThread->Notify();
244      }
245  }
246  
247  static bool RegisterForEvent(const char* evtName)
248  {
249      auto qmessage = SendQueue.GetNextAddMessage();
250      if (qmessage) {
251          qmessage->length =
252            JsonWriteSubscribeCommand(qmessage->buffer, sizeof(qmessage->buffer), Nonce++, evtName);
253          SendQueue.CommitAdd();
254          SignalIOActivity();
255          return true;
256      }
257      return false;
258  }
259  
260  static bool DeregisterForEvent(const char* evtName)
261  {
262      auto qmessage = SendQueue.GetNextAddMessage();
263      if (qmessage) {
264          qmessage->length =
265            JsonWriteUnsubscribeCommand(qmessage->buffer, sizeof(qmessage->buffer), Nonce++, evtName);
266          SendQueue.CommitAdd();
267          SignalIOActivity();
268          return true;
269      }
270      return false;
271  }
272  
273  extern "C" DISCORD_EXPORT void Discord_Initialize(const char* applicationId,
274                                                    DiscordEventHandlers* handlers,
275                                                    int autoRegister,
276                                                    const char* optionalSteamId)
277  {
278      IoThread = new (std::nothrow) IoThreadHolder();
279      if (IoThread == nullptr) {
280          return;
281      }
282  
283      if (autoRegister) {
284          if (optionalSteamId && optionalSteamId[0]) {
285              Discord_RegisterSteamGame(applicationId, optionalSteamId);
286          }
287          else {
288              Discord_Register(applicationId, nullptr);
289          }
290      }
291  
292      Pid = GetProcessId();
293  
294      {
295          std::lock_guard<std::mutex> guard(HandlerMutex);
296  
297          if (handlers) {
298              QueuedHandlers = *handlers;
299          }
300          else {
301              QueuedHandlers = {};
302          }
303  
304          Handlers = {};
305      }
306  
307      if (Connection) {
308          return;
309      }
310  
311      Connection = RpcConnection::Create(applicationId);
312      Connection->onConnect = [](JsonDocument& readyMessage) {
313          Discord_UpdateHandlers(&QueuedHandlers);
314          if (QueuedPresence.length > 0) {
315              UpdatePresence.exchange(true);
316              SignalIOActivity();
317          }
318          auto data = GetObjMember(&readyMessage, "data");
319          auto user = GetObjMember(data, "user");
320          auto userId = GetStrMember(user, "id");
321          auto username = GetStrMember(user, "username");
322          auto avatar = GetStrMember(user, "avatar");
323          if (userId && username) {
324              StringCopy(connectedUser.userId, userId);
325              StringCopy(connectedUser.username, username);
326              auto discriminator = GetStrMember(user, "discriminator");
327              if (discriminator) {
328                  StringCopy(connectedUser.discriminator, discriminator);
329              }
330              if (avatar) {
331                  StringCopy(connectedUser.avatar, avatar);
332              }
333              else {
334                  connectedUser.avatar[0] = 0;
335              }
336          }
337          WasJustConnected.exchange(true);
338          ReconnectTimeMs.reset();
339      };
340      Connection->onDisconnect = [](int err, const char* message) {
341          LastDisconnectErrorCode = err;
342          StringCopy(LastDisconnectErrorMessage, message);
343          WasJustDisconnected.exchange(true);
344          UpdateReconnectTime();
345      };
346  
347      IoThread->Start();
348  }
349  
350  extern "C" DISCORD_EXPORT void Discord_Shutdown(void)
351  {
352      if (!Connection) {
353          return;
354      }
355      Connection->onConnect = nullptr;
356      Connection->onDisconnect = nullptr;
357      Handlers = {};
358      QueuedPresence.length = 0;
359      UpdatePresence.exchange(false);
360      if (IoThread != nullptr) {
361          IoThread->Stop();
362          delete IoThread;
363          IoThread = nullptr;
364      }
365  
366      RpcConnection::Destroy(Connection);
367  }
368  
369  extern "C" DISCORD_EXPORT void Discord_UpdatePresence(const DiscordRichPresence* presence)
370  {
371      {
372          std::lock_guard<std::mutex> guard(PresenceMutex);
373          QueuedPresence.length = JsonWriteRichPresenceObj(
374            QueuedPresence.buffer, sizeof(QueuedPresence.buffer), Nonce++, Pid, presence);
375          UpdatePresence.exchange(true);
376      }
377      SignalIOActivity();
378  }
379  
380  extern "C" DISCORD_EXPORT void Discord_ClearPresence(void)
381  {
382      Discord_UpdatePresence(nullptr);
383  }
384  
385  extern "C" DISCORD_EXPORT void Discord_Respond(const char* userId, /* DISCORD_REPLY_ */ int reply)
386  {
387      // if we are not connected, let's not batch up stale messages for later
388      if (!Connection || !Connection->IsOpen()) {
389          return;
390      }
391      auto qmessage = SendQueue.GetNextAddMessage();
392      if (qmessage) {
393          qmessage->length =
394            JsonWriteJoinReply(qmessage->buffer, sizeof(qmessage->buffer), userId, reply, Nonce++);
395          SendQueue.CommitAdd();
396          SignalIOActivity();
397      }
398  }
399  
400  extern "C" DISCORD_EXPORT void Discord_RunCallbacks(void)
401  {
402      // Note on some weirdness: internally we might connect, get other signals, disconnect any number
403      // of times inbetween calls here. Externally, we want the sequence to seem sane, so any other
404      // signals are book-ended by calls to ready and disconnect.
405  
406      if (!Connection) {
407          return;
408      }
409  
410      bool wasDisconnected = WasJustDisconnected.exchange(false);
411      bool isConnected = Connection->IsOpen();
412  
413      if (isConnected) {
414          // if we are connected, disconnect cb first
415          std::lock_guard<std::mutex> guard(HandlerMutex);
416          if (wasDisconnected && Handlers.disconnected) {
417              Handlers.disconnected(LastDisconnectErrorCode, LastDisconnectErrorMessage);
418          }
419      }
420  
421      if (WasJustConnected.exchange(false)) {
422          std::lock_guard<std::mutex> guard(HandlerMutex);
423          if (Handlers.ready) {
424              DiscordUser du{connectedUser.userId,
425                             connectedUser.username,
426                             connectedUser.discriminator,
427                             connectedUser.avatar};
428              Handlers.ready(&du);
429          }
430      }
431  
432      if (GotErrorMessage.exchange(false)) {
433          std::lock_guard<std::mutex> guard(HandlerMutex);
434          if (Handlers.errored) {
435              Handlers.errored(LastErrorCode, LastErrorMessage);
436          }
437      }
438  
439      if (WasJoinGame.exchange(false)) {
440          std::lock_guard<std::mutex> guard(HandlerMutex);
441          if (Handlers.joinGame) {
442              Handlers.joinGame(JoinGameSecret);
443          }
444      }
445  
446      if (WasSpectateGame.exchange(false)) {
447          std::lock_guard<std::mutex> guard(HandlerMutex);
448          if (Handlers.spectateGame) {
449              Handlers.spectateGame(SpectateGameSecret);
450          }
451      }
452  
453      // Right now this batches up any requests and sends them all in a burst; I could imagine a world
454      // where the implementer would rather sequentially accept/reject each one before the next invite
455      // is sent. I left it this way because I could also imagine wanting to process these all and
456      // maybe show them in one common dialog and/or start fetching the avatars in parallel, and if
457      // not it should be trivial for the implementer to make a queue themselves.
458      while (JoinAskQueue.HavePendingSends()) {
459          auto req = JoinAskQueue.GetNextSendMessage();
460          {
461              std::lock_guard<std::mutex> guard(HandlerMutex);
462              if (Handlers.joinRequest) {
463                  DiscordUser du{req->userId, req->username, req->discriminator, req->avatar};
464                  Handlers.joinRequest(&du);
465              }
466          }
467          JoinAskQueue.CommitSend();
468      }
469  
470      if (!isConnected) {
471          // if we are not connected, disconnect message last
472          std::lock_guard<std::mutex> guard(HandlerMutex);
473          if (wasDisconnected && Handlers.disconnected) {
474              Handlers.disconnected(LastDisconnectErrorCode, LastDisconnectErrorMessage);
475          }
476      }
477  }
478  
479  extern "C" DISCORD_EXPORT void Discord_UpdateHandlers(DiscordEventHandlers* newHandlers)
480  {
481      if (newHandlers) {
482  #define HANDLE_EVENT_REGISTRATION(handler_name, event)              \
483      if (!Handlers.handler_name && newHandlers->handler_name) {      \
484          RegisterForEvent(event);                                    \
485      }                                                               \
486      else if (Handlers.handler_name && !newHandlers->handler_name) { \
487          DeregisterForEvent(event);                                  \
488      }
489  
490          std::lock_guard<std::mutex> guard(HandlerMutex);
491          HANDLE_EVENT_REGISTRATION(joinGame, "ACTIVITY_JOIN")
492          HANDLE_EVENT_REGISTRATION(spectateGame, "ACTIVITY_SPECTATE")
493          HANDLE_EVENT_REGISTRATION(joinRequest, "ACTIVITY_JOIN_REQUEST")
494  
495  #undef HANDLE_EVENT_REGISTRATION
496  
497          Handlers = *newHandlers;
498      }
499      else {
500          std::lock_guard<std::mutex> guard(HandlerMutex);
501          Handlers = {};
502      }
503      return;
504  }