/ src / util / sock.cpp
sock.cpp
  1  // Copyright (c) 2020-present The Bitcoin Core developers
  2  // Distributed under the MIT software license, see the accompanying
  3  // file COPYING or http://www.opensource.org/licenses/mit-license.php.
  4  
  5  #include <util/sock.h>
  6  
  7  #include <compat/compat.h>
  8  #include <span.h>
  9  #include <tinyformat.h>
 10  #include <util/check.h>
 11  #include <util/log.h>
 12  #include <util/syserror.h>
 13  #include <util/threadinterrupt.h>
 14  #include <util/time.h>
 15  
 16  #include <algorithm>
 17  #include <compare>
 18  #include <exception>
 19  #include <memory>
 20  #include <stdexcept>
 21  #include <string>
 22  #include <utility>
 23  #include <vector>
 24  
 25  #ifdef USE_POLL
 26  #include <poll.h>
 27  #endif
 28  
 29  static inline bool IOErrorIsPermanent(int err)
 30  {
 31      return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
 32  }
 33  
 34  Sock::Sock(SOCKET s) : m_socket(s) {}
 35  
 36  Sock::Sock(Sock&& other)
 37  {
 38      m_socket = other.m_socket;
 39      other.m_socket = INVALID_SOCKET;
 40  }
 41  
 42  Sock::~Sock() { Close(); }
 43  
 44  Sock& Sock::operator=(Sock&& other)
 45  {
 46      Close();
 47      m_socket = other.m_socket;
 48      other.m_socket = INVALID_SOCKET;
 49      return *this;
 50  }
 51  
 52  ssize_t Sock::Send(const void* data, size_t len, int flags) const
 53  {
 54      return send(m_socket, static_cast<const char*>(data), len, flags);
 55  }
 56  
 57  ssize_t Sock::Recv(void* buf, size_t len, int flags) const
 58  {
 59      return recv(m_socket, static_cast<char*>(buf), len, flags);
 60  }
 61  
 62  int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
 63  {
 64      return connect(m_socket, addr, addr_len);
 65  }
 66  
 67  int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
 68  {
 69      return bind(m_socket, addr, addr_len);
 70  }
 71  
 72  int Sock::Listen(int backlog) const
 73  {
 74      return listen(m_socket, backlog);
 75  }
 76  
 77  std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
 78  {
 79  #ifdef WIN32
 80      static constexpr auto ERR = INVALID_SOCKET;
 81  #else
 82      static constexpr auto ERR = SOCKET_ERROR;
 83  #endif
 84  
 85      std::unique_ptr<Sock> sock;
 86  
 87      const auto socket = accept(m_socket, addr, addr_len);
 88      if (socket != ERR) {
 89          try {
 90              sock = std::make_unique<Sock>(socket);
 91          } catch (const std::exception&) {
 92  #ifdef WIN32
 93              closesocket(socket);
 94  #else
 95              close(socket);
 96  #endif
 97          }
 98      }
 99  
100      return sock;
101  }
102  
103  int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
104  {
105      return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
106  }
107  
108  int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
109  {
110      return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
111  }
112  
113  int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
114  {
115      return getsockname(m_socket, name, name_len);
116  }
117  
118  bool Sock::SetNonBlocking() const
119  {
120  #ifdef WIN32
121      u_long on{1};
122      if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
123          return false;
124      }
125  #else
126      const int flags{fcntl(m_socket, F_GETFL, 0)};
127      if (flags == SOCKET_ERROR) {
128          return false;
129      }
130      if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
131          return false;
132      }
133  #endif
134      return true;
135  }
136  
137  bool Sock::IsSelectable() const
138  {
139  #if defined(USE_POLL) || defined(WIN32)
140      return true;
141  #else
142      return m_socket < FD_SETSIZE;
143  #endif
144  }
145  
146  bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
147  {
148      // We need a `shared_ptr` holding `this` for `WaitMany()`, but don't want
149      // `this` to be destroyed when the `shared_ptr` goes out of scope at the
150      // end of this function.
151      // Create it with an aliasing shared_ptr that points to `this` without
152      // owning it.
153      std::shared_ptr<const Sock> shared{std::shared_ptr<const Sock>{}, this};
154  
155      EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
156  
157      if (!WaitMany(timeout, events_per_sock)) {
158          return false;
159      }
160  
161      if (occurred != nullptr) {
162          *occurred = events_per_sock.begin()->second.occurred;
163      }
164  
165      return true;
166  }
167  
168  bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
169  {
170  #ifdef USE_POLL
171      std::vector<pollfd> pfds;
172      for (const auto& [sock, events] : events_per_sock) {
173          pfds.emplace_back();
174          auto& pfd = pfds.back();
175          pfd.fd = sock->m_socket;
176          if (events.requested & RECV) {
177              pfd.events |= POLLIN;
178          }
179          if (events.requested & SEND) {
180              pfd.events |= POLLOUT;
181          }
182      }
183  
184      if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
185          return false;
186      }
187  
188      assert(pfds.size() == events_per_sock.size());
189      size_t i{0};
190      for (auto& [sock, events] : events_per_sock) {
191          assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
192          events.occurred = 0;
193          if (pfds[i].revents & POLLIN) {
194              events.occurred |= RECV;
195          }
196          if (pfds[i].revents & POLLOUT) {
197              events.occurred |= SEND;
198          }
199          if (pfds[i].revents & (POLLERR | POLLHUP)) {
200              events.occurred |= ERR;
201          }
202          ++i;
203      }
204  
205      return true;
206  #else
207      fd_set recv;
208      fd_set send;
209      fd_set err;
210      FD_ZERO(&recv);
211      FD_ZERO(&send);
212      FD_ZERO(&err);
213      SOCKET socket_max{0};
214  
215      for (const auto& [sock, events] : events_per_sock) {
216          if (!sock->IsSelectable()) {
217              return false;
218          }
219          const auto& s = sock->m_socket;
220          if (events.requested & RECV) {
221              FD_SET(s, &recv);
222          }
223          if (events.requested & SEND) {
224              FD_SET(s, &send);
225          }
226          FD_SET(s, &err);
227          socket_max = std::max(socket_max, s);
228      }
229  
230      timeval tv = MillisToTimeval(timeout);
231  
232      if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
233          return false;
234      }
235  
236      for (auto& [sock, events] : events_per_sock) {
237          const auto& s = sock->m_socket;
238          events.occurred = 0;
239          if (FD_ISSET(s, &recv)) {
240              events.occurred |= RECV;
241          }
242          if (FD_ISSET(s, &send)) {
243              events.occurred |= SEND;
244          }
245          if (FD_ISSET(s, &err)) {
246              events.occurred |= ERR;
247          }
248      }
249  
250      return true;
251  #endif /* USE_POLL */
252  }
253  
254  void Sock::SendComplete(std::span<const unsigned char> data,
255                          std::chrono::milliseconds timeout,
256                          CThreadInterrupt& interrupt) const
257  {
258      const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
259      size_t sent{0};
260  
261      for (;;) {
262          const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
263  
264          if (ret > 0) {
265              sent += static_cast<size_t>(ret);
266              if (sent == data.size()) {
267                  break;
268              }
269          } else {
270              const int err{WSAGetLastError()};
271              if (IOErrorIsPermanent(err)) {
272                  throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
273              }
274          }
275  
276          const auto now = GetTime<std::chrono::milliseconds>();
277  
278          if (now >= deadline) {
279              throw std::runtime_error(strprintf(
280                  "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
281          }
282  
283          if (interrupt) {
284              throw std::runtime_error(strprintf(
285                  "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
286          }
287  
288          // Wait for a short while (or the socket to become ready for sending) before retrying
289          // if nothing was sent.
290          const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
291          (void)Wait(wait_time, SEND);
292      }
293  }
294  
295  void Sock::SendComplete(std::span<const char> data,
296                          std::chrono::milliseconds timeout,
297                          CThreadInterrupt& interrupt) const
298  {
299      SendComplete(MakeUCharSpan(data), timeout, interrupt);
300  }
301  
302  std::string Sock::RecvUntilTerminator(uint8_t terminator,
303                                        std::chrono::milliseconds timeout,
304                                        CThreadInterrupt& interrupt,
305                                        size_t max_data) const
306  {
307      const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
308      std::string data;
309      bool terminator_found{false};
310  
311      // We must not consume any bytes past the terminator from the socket.
312      // One option is to read one byte at a time and check if we have read a terminator.
313      // However that is very slow. Instead, we peek at what is in the socket and only read
314      // as many bytes as possible without crossing the terminator.
315      // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
316      // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
317      // at a time is about 50 times slower.
318  
319      for (;;) {
320          if (data.size() >= max_data) {
321              throw std::runtime_error(
322                  strprintf("Received too many bytes without a terminator (%u)", data.size()));
323          }
324  
325          char buf[512];
326  
327          const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
328  
329          switch (peek_ret) {
330          case -1: {
331              const int err{WSAGetLastError()};
332              if (IOErrorIsPermanent(err)) {
333                  throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
334              }
335              break;
336          }
337          case 0:
338              throw std::runtime_error("Connection unexpectedly closed by peer");
339          default:
340              auto end = buf + peek_ret;
341              auto terminator_pos = std::find(buf, end, terminator);
342              terminator_found = terminator_pos != end;
343  
344              const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
345                                                      static_cast<size_t>(peek_ret)};
346  
347              const ssize_t read_ret{Recv(buf, try_len, 0)};
348  
349              if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
350                  throw std::runtime_error(
351                      strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
352                                "peek claimed %u bytes are available",
353                                read_ret, try_len, peek_ret));
354              }
355  
356              // Don't include the terminator in the output.
357              const size_t append_len{terminator_found ? try_len - 1 : try_len};
358  
359              data.append(buf, buf + append_len);
360  
361              if (terminator_found) {
362                  return data;
363              }
364          }
365  
366          const auto now = GetTime<std::chrono::milliseconds>();
367  
368          if (now >= deadline) {
369              throw std::runtime_error(strprintf(
370                  "Receive timeout (received %u bytes without terminator before that)", data.size()));
371          }
372  
373          if (interrupt) {
374              throw std::runtime_error(strprintf(
375                  "Receive interrupted (received %u bytes without terminator before that)",
376                  data.size()));
377          }
378  
379          // Wait for a short while (or the socket to become ready for reading) before retrying.
380          const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
381          (void)Wait(wait_time, RECV);
382      }
383  }
384  
385  bool Sock::IsConnected(std::string& errmsg) const
386  {
387      if (m_socket == INVALID_SOCKET) {
388          errmsg = "not connected";
389          return false;
390      }
391  
392      char c;
393      switch (Recv(&c, sizeof(c), MSG_PEEK)) {
394      case -1: {
395          const int err = WSAGetLastError();
396          if (IOErrorIsPermanent(err)) {
397              errmsg = NetworkErrorString(err);
398              return false;
399          }
400          return true;
401      }
402      case 0:
403          errmsg = "closed";
404          return false;
405      default:
406          return true;
407      }
408  }
409  
410  void Sock::Close()
411  {
412      if (m_socket == INVALID_SOCKET) {
413          return;
414      }
415  #ifdef WIN32
416      int ret = closesocket(m_socket);
417  #else
418      int ret = close(m_socket);
419  #endif
420      if (ret) {
421          LogWarning("Error closing socket %d: %s", m_socket, NetworkErrorString(WSAGetLastError()));
422      }
423      m_socket = INVALID_SOCKET;
424  }
425  
426  bool Sock::operator==(SOCKET s) const
427  {
428      return m_socket == s;
429  };
430  
431  std::string NetworkErrorString(int err)
432  {
433  #if defined(WIN32)
434      return Win32ErrorString(err);
435  #else
436      // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
437      return SysErrorString(err);
438  #endif
439  }