/ src / util / sock.cpp
sock.cpp
  1  // Copyright (c) 2020-2022 The Bitcoin Core developers
  2  // Distributed under the MIT software license, see the accompanying
  3  // file COPYING or http://www.opensource.org/licenses/mit-license.php.
  4  
  5  #include <common/system.h>
  6  #include <compat/compat.h>
  7  #include <logging.h>
  8  #include <tinyformat.h>
  9  #include <util/sock.h>
 10  #include <util/syserror.h>
 11  #include <util/threadinterrupt.h>
 12  #include <util/time.h>
 13  
 14  #include <memory>
 15  #include <stdexcept>
 16  #include <string>
 17  
 18  #ifdef USE_POLL
 19  #include <poll.h>
 20  #endif
 21  
 22  static inline bool IOErrorIsPermanent(int err)
 23  {
 24      return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
 25  }
 26  
 27  Sock::Sock(SOCKET s) : m_socket(s) {}
 28  
 29  Sock::Sock(Sock&& other)
 30  {
 31      m_socket = other.m_socket;
 32      other.m_socket = INVALID_SOCKET;
 33  }
 34  
 35  Sock::~Sock() { Close(); }
 36  
 37  Sock& Sock::operator=(Sock&& other)
 38  {
 39      Close();
 40      m_socket = other.m_socket;
 41      other.m_socket = INVALID_SOCKET;
 42      return *this;
 43  }
 44  
 45  ssize_t Sock::Send(const void* data, size_t len, int flags) const
 46  {
 47      return send(m_socket, static_cast<const char*>(data), len, flags);
 48  }
 49  
 50  ssize_t Sock::Recv(void* buf, size_t len, int flags) const
 51  {
 52      return recv(m_socket, static_cast<char*>(buf), len, flags);
 53  }
 54  
 55  int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
 56  {
 57      return connect(m_socket, addr, addr_len);
 58  }
 59  
 60  int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
 61  {
 62      return bind(m_socket, addr, addr_len);
 63  }
 64  
 65  int Sock::Listen(int backlog) const
 66  {
 67      return listen(m_socket, backlog);
 68  }
 69  
 70  std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
 71  {
 72  #ifdef WIN32
 73      static constexpr auto ERR = INVALID_SOCKET;
 74  #else
 75      static constexpr auto ERR = SOCKET_ERROR;
 76  #endif
 77  
 78      std::unique_ptr<Sock> sock;
 79  
 80      const auto socket = accept(m_socket, addr, addr_len);
 81      if (socket != ERR) {
 82          try {
 83              sock = std::make_unique<Sock>(socket);
 84          } catch (const std::exception&) {
 85  #ifdef WIN32
 86              closesocket(socket);
 87  #else
 88              close(socket);
 89  #endif
 90          }
 91      }
 92  
 93      return sock;
 94  }
 95  
 96  int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
 97  {
 98      return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
 99  }
100  
101  int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
102  {
103      return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
104  }
105  
106  int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
107  {
108      return getsockname(m_socket, name, name_len);
109  }
110  
111  bool Sock::SetNonBlocking() const
112  {
113  #ifdef WIN32
114      u_long on{1};
115      if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
116          return false;
117      }
118  #else
119      const int flags{fcntl(m_socket, F_GETFL, 0)};
120      if (flags == SOCKET_ERROR) {
121          return false;
122      }
123      if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
124          return false;
125      }
126  #endif
127      return true;
128  }
129  
130  bool Sock::IsSelectable() const
131  {
132  #if defined(USE_POLL) || defined(WIN32)
133      return true;
134  #else
135      return m_socket < FD_SETSIZE;
136  #endif
137  }
138  
139  bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
140  {
141      // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want
142      // `this` to be destroyed when the `shared_ptr` goes out of scope at the
143      // end of this function. Create it with a custom noop deleter.
144      std::shared_ptr<const Sock> shared{this, [](const Sock*) {}};
145  
146      EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
147  
148      if (!WaitMany(timeout, events_per_sock)) {
149          return false;
150      }
151  
152      if (occurred != nullptr) {
153          *occurred = events_per_sock.begin()->second.occurred;
154      }
155  
156      return true;
157  }
158  
159  bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
160  {
161  #ifdef USE_POLL
162      std::vector<pollfd> pfds;
163      for (const auto& [sock, events] : events_per_sock) {
164          pfds.emplace_back();
165          auto& pfd = pfds.back();
166          pfd.fd = sock->m_socket;
167          if (events.requested & RECV) {
168              pfd.events |= POLLIN;
169          }
170          if (events.requested & SEND) {
171              pfd.events |= POLLOUT;
172          }
173      }
174  
175      if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
176          return false;
177      }
178  
179      assert(pfds.size() == events_per_sock.size());
180      size_t i{0};
181      for (auto& [sock, events] : events_per_sock) {
182          assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
183          events.occurred = 0;
184          if (pfds[i].revents & POLLIN) {
185              events.occurred |= RECV;
186          }
187          if (pfds[i].revents & POLLOUT) {
188              events.occurred |= SEND;
189          }
190          if (pfds[i].revents & (POLLERR | POLLHUP)) {
191              events.occurred |= ERR;
192          }
193          ++i;
194      }
195  
196      return true;
197  #else
198      fd_set recv;
199      fd_set send;
200      fd_set err;
201      FD_ZERO(&recv);
202      FD_ZERO(&send);
203      FD_ZERO(&err);
204      SOCKET socket_max{0};
205  
206      for (const auto& [sock, events] : events_per_sock) {
207          if (!sock->IsSelectable()) {
208              return false;
209          }
210          const auto& s = sock->m_socket;
211          if (events.requested & RECV) {
212              FD_SET(s, &recv);
213          }
214          if (events.requested & SEND) {
215              FD_SET(s, &send);
216          }
217          FD_SET(s, &err);
218          socket_max = std::max(socket_max, s);
219      }
220  
221      timeval tv = MillisToTimeval(timeout);
222  
223      if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
224          return false;
225      }
226  
227      for (auto& [sock, events] : events_per_sock) {
228          const auto& s = sock->m_socket;
229          events.occurred = 0;
230          if (FD_ISSET(s, &recv)) {
231              events.occurred |= RECV;
232          }
233          if (FD_ISSET(s, &send)) {
234              events.occurred |= SEND;
235          }
236          if (FD_ISSET(s, &err)) {
237              events.occurred |= ERR;
238          }
239      }
240  
241      return true;
242  #endif /* USE_POLL */
243  }
244  
245  void Sock::SendComplete(Span<const unsigned char> data,
246                          std::chrono::milliseconds timeout,
247                          CThreadInterrupt& interrupt) const
248  {
249      const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
250      size_t sent{0};
251  
252      for (;;) {
253          const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
254  
255          if (ret > 0) {
256              sent += static_cast<size_t>(ret);
257              if (sent == data.size()) {
258                  break;
259              }
260          } else {
261              const int err{WSAGetLastError()};
262              if (IOErrorIsPermanent(err)) {
263                  throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
264              }
265          }
266  
267          const auto now = GetTime<std::chrono::milliseconds>();
268  
269          if (now >= deadline) {
270              throw std::runtime_error(strprintf(
271                  "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
272          }
273  
274          if (interrupt) {
275              throw std::runtime_error(strprintf(
276                  "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
277          }
278  
279          // Wait for a short while (or the socket to become ready for sending) before retrying
280          // if nothing was sent.
281          const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
282          (void)Wait(wait_time, SEND);
283      }
284  }
285  
286  void Sock::SendComplete(Span<const char> data,
287                          std::chrono::milliseconds timeout,
288                          CThreadInterrupt& interrupt) const
289  {
290      SendComplete(MakeUCharSpan(data), timeout, interrupt);
291  }
292  
293  std::string Sock::RecvUntilTerminator(uint8_t terminator,
294                                        std::chrono::milliseconds timeout,
295                                        CThreadInterrupt& interrupt,
296                                        size_t max_data) const
297  {
298      const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
299      std::string data;
300      bool terminator_found{false};
301  
302      // We must not consume any bytes past the terminator from the socket.
303      // One option is to read one byte at a time and check if we have read a terminator.
304      // However that is very slow. Instead, we peek at what is in the socket and only read
305      // as many bytes as possible without crossing the terminator.
306      // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
307      // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
308      // at a time is about 50 times slower.
309  
310      for (;;) {
311          if (data.size() >= max_data) {
312              throw std::runtime_error(
313                  strprintf("Received too many bytes without a terminator (%u)", data.size()));
314          }
315  
316          char buf[512];
317  
318          const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
319  
320          switch (peek_ret) {
321          case -1: {
322              const int err{WSAGetLastError()};
323              if (IOErrorIsPermanent(err)) {
324                  throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
325              }
326              break;
327          }
328          case 0:
329              throw std::runtime_error("Connection unexpectedly closed by peer");
330          default:
331              auto end = buf + peek_ret;
332              auto terminator_pos = std::find(buf, end, terminator);
333              terminator_found = terminator_pos != end;
334  
335              const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
336                                                      static_cast<size_t>(peek_ret)};
337  
338              const ssize_t read_ret{Recv(buf, try_len, 0)};
339  
340              if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
341                  throw std::runtime_error(
342                      strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
343                                "peek claimed %u bytes are available",
344                                read_ret, try_len, peek_ret));
345              }
346  
347              // Don't include the terminator in the output.
348              const size_t append_len{terminator_found ? try_len - 1 : try_len};
349  
350              data.append(buf, buf + append_len);
351  
352              if (terminator_found) {
353                  return data;
354              }
355          }
356  
357          const auto now = GetTime<std::chrono::milliseconds>();
358  
359          if (now >= deadline) {
360              throw std::runtime_error(strprintf(
361                  "Receive timeout (received %u bytes without terminator before that)", data.size()));
362          }
363  
364          if (interrupt) {
365              throw std::runtime_error(strprintf(
366                  "Receive interrupted (received %u bytes without terminator before that)",
367                  data.size()));
368          }
369  
370          // Wait for a short while (or the socket to become ready for reading) before retrying.
371          const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
372          (void)Wait(wait_time, RECV);
373      }
374  }
375  
376  bool Sock::IsConnected(std::string& errmsg) const
377  {
378      if (m_socket == INVALID_SOCKET) {
379          errmsg = "not connected";
380          return false;
381      }
382  
383      char c;
384      switch (Recv(&c, sizeof(c), MSG_PEEK)) {
385      case -1: {
386          const int err = WSAGetLastError();
387          if (IOErrorIsPermanent(err)) {
388              errmsg = NetworkErrorString(err);
389              return false;
390          }
391          return true;
392      }
393      case 0:
394          errmsg = "closed";
395          return false;
396      default:
397          return true;
398      }
399  }
400  
401  void Sock::Close()
402  {
403      if (m_socket == INVALID_SOCKET) {
404          return;
405      }
406  #ifdef WIN32
407      int ret = closesocket(m_socket);
408  #else
409      int ret = close(m_socket);
410  #endif
411      if (ret) {
412          LogPrintf("Error closing socket %d: %s\n", m_socket, NetworkErrorString(WSAGetLastError()));
413      }
414      m_socket = INVALID_SOCKET;
415  }
416  
417  bool Sock::operator==(SOCKET s) const
418  {
419      return m_socket == s;
420  };
421  
422  std::string NetworkErrorString(int err)
423  {
424  #if defined(WIN32)
425      return Win32ErrorString(err);
426  #else
427      // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
428      return SysErrorString(err);
429  #endif
430  }