/ src / util / sock.cpp
sock.cpp
  1  // Copyright (c) 2020-present The Bitcoin Core developers
  2  // Distributed under the MIT software license, see the accompanying
  3  // file COPYING or http://www.opensource.org/licenses/mit-license.php.
  4  
  5  #include <util/sock.h>
  6  
  7  #include <common/system.h>
  8  #include <compat/compat.h>
  9  #include <span.h>
 10  #include <tinyformat.h>
 11  #include <util/log.h>
 12  #include <util/syserror.h>
 13  #include <util/threadinterrupt.h>
 14  #include <util/time.h>
 15  
 16  #include <memory>
 17  #include <stdexcept>
 18  #include <string>
 19  
 20  #ifdef USE_POLL
 21  #include <poll.h>
 22  #endif
 23  
 24  static inline bool IOErrorIsPermanent(int err)
 25  {
 26      return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
 27  }
 28  
 29  Sock::Sock(SOCKET s) : m_socket(s) {}
 30  
 31  Sock::Sock(Sock&& other)
 32  {
 33      m_socket = other.m_socket;
 34      other.m_socket = INVALID_SOCKET;
 35  }
 36  
 37  Sock::~Sock() { Close(); }
 38  
 39  Sock& Sock::operator=(Sock&& other)
 40  {
 41      Close();
 42      m_socket = other.m_socket;
 43      other.m_socket = INVALID_SOCKET;
 44      return *this;
 45  }
 46  
 47  ssize_t Sock::Send(const void* data, size_t len, int flags) const
 48  {
 49      return send(m_socket, static_cast<const char*>(data), len, flags);
 50  }
 51  
 52  ssize_t Sock::Recv(void* buf, size_t len, int flags) const
 53  {
 54      return recv(m_socket, static_cast<char*>(buf), len, flags);
 55  }
 56  
 57  int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
 58  {
 59      return connect(m_socket, addr, addr_len);
 60  }
 61  
 62  int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
 63  {
 64      return bind(m_socket, addr, addr_len);
 65  }
 66  
 67  int Sock::Listen(int backlog) const
 68  {
 69      return listen(m_socket, backlog);
 70  }
 71  
 72  std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
 73  {
 74  #ifdef WIN32
 75      static constexpr auto ERR = INVALID_SOCKET;
 76  #else
 77      static constexpr auto ERR = SOCKET_ERROR;
 78  #endif
 79  
 80      std::unique_ptr<Sock> sock;
 81  
 82      const auto socket = accept(m_socket, addr, addr_len);
 83      if (socket != ERR) {
 84          try {
 85              sock = std::make_unique<Sock>(socket);
 86          } catch (const std::exception&) {
 87  #ifdef WIN32
 88              closesocket(socket);
 89  #else
 90              close(socket);
 91  #endif
 92          }
 93      }
 94  
 95      return sock;
 96  }
 97  
 98  int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
 99  {
100      return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
101  }
102  
103  int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
104  {
105      return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
106  }
107  
108  int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
109  {
110      return getsockname(m_socket, name, name_len);
111  }
112  
113  bool Sock::SetNonBlocking() const
114  {
115  #ifdef WIN32
116      u_long on{1};
117      if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
118          return false;
119      }
120  #else
121      const int flags{fcntl(m_socket, F_GETFL, 0)};
122      if (flags == SOCKET_ERROR) {
123          return false;
124      }
125      if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
126          return false;
127      }
128  #endif
129      return true;
130  }
131  
132  bool Sock::IsSelectable() const
133  {
134  #if defined(USE_POLL) || defined(WIN32)
135      return true;
136  #else
137      return m_socket < FD_SETSIZE;
138  #endif
139  }
140  
141  bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
142  {
143      // We need a `shared_ptr` holding `this` for `WaitMany()`, but don't want
144      // `this` to be destroyed when the `shared_ptr` goes out of scope at the
145      // end of this function.
146      // Create it with an aliasing shared_ptr that points to `this` without
147      // owning it.
148      std::shared_ptr<const Sock> shared{std::shared_ptr<const Sock>{}, this};
149  
150      EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
151  
152      if (!WaitMany(timeout, events_per_sock)) {
153          return false;
154      }
155  
156      if (occurred != nullptr) {
157          *occurred = events_per_sock.begin()->second.occurred;
158      }
159  
160      return true;
161  }
162  
163  bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
164  {
165  #ifdef USE_POLL
166      std::vector<pollfd> pfds;
167      for (const auto& [sock, events] : events_per_sock) {
168          pfds.emplace_back();
169          auto& pfd = pfds.back();
170          pfd.fd = sock->m_socket;
171          if (events.requested & RECV) {
172              pfd.events |= POLLIN;
173          }
174          if (events.requested & SEND) {
175              pfd.events |= POLLOUT;
176          }
177      }
178  
179      if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
180          return false;
181      }
182  
183      assert(pfds.size() == events_per_sock.size());
184      size_t i{0};
185      for (auto& [sock, events] : events_per_sock) {
186          assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
187          events.occurred = 0;
188          if (pfds[i].revents & POLLIN) {
189              events.occurred |= RECV;
190          }
191          if (pfds[i].revents & POLLOUT) {
192              events.occurred |= SEND;
193          }
194          if (pfds[i].revents & (POLLERR | POLLHUP)) {
195              events.occurred |= ERR;
196          }
197          ++i;
198      }
199  
200      return true;
201  #else
202      fd_set recv;
203      fd_set send;
204      fd_set err;
205      FD_ZERO(&recv);
206      FD_ZERO(&send);
207      FD_ZERO(&err);
208      SOCKET socket_max{0};
209  
210      for (const auto& [sock, events] : events_per_sock) {
211          if (!sock->IsSelectable()) {
212              return false;
213          }
214          const auto& s = sock->m_socket;
215          if (events.requested & RECV) {
216              FD_SET(s, &recv);
217          }
218          if (events.requested & SEND) {
219              FD_SET(s, &send);
220          }
221          FD_SET(s, &err);
222          socket_max = std::max(socket_max, s);
223      }
224  
225      timeval tv = MillisToTimeval(timeout);
226  
227      if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
228          return false;
229      }
230  
231      for (auto& [sock, events] : events_per_sock) {
232          const auto& s = sock->m_socket;
233          events.occurred = 0;
234          if (FD_ISSET(s, &recv)) {
235              events.occurred |= RECV;
236          }
237          if (FD_ISSET(s, &send)) {
238              events.occurred |= SEND;
239          }
240          if (FD_ISSET(s, &err)) {
241              events.occurred |= ERR;
242          }
243      }
244  
245      return true;
246  #endif /* USE_POLL */
247  }
248  
249  void Sock::SendComplete(std::span<const unsigned char> data,
250                          std::chrono::milliseconds timeout,
251                          CThreadInterrupt& interrupt) const
252  {
253      const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
254      size_t sent{0};
255  
256      for (;;) {
257          const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
258  
259          if (ret > 0) {
260              sent += static_cast<size_t>(ret);
261              if (sent == data.size()) {
262                  break;
263              }
264          } else {
265              const int err{WSAGetLastError()};
266              if (IOErrorIsPermanent(err)) {
267                  throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
268              }
269          }
270  
271          const auto now = GetTime<std::chrono::milliseconds>();
272  
273          if (now >= deadline) {
274              throw std::runtime_error(strprintf(
275                  "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
276          }
277  
278          if (interrupt) {
279              throw std::runtime_error(strprintf(
280                  "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
281          }
282  
283          // Wait for a short while (or the socket to become ready for sending) before retrying
284          // if nothing was sent.
285          const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
286          (void)Wait(wait_time, SEND);
287      }
288  }
289  
290  void Sock::SendComplete(std::span<const char> data,
291                          std::chrono::milliseconds timeout,
292                          CThreadInterrupt& interrupt) const
293  {
294      SendComplete(MakeUCharSpan(data), timeout, interrupt);
295  }
296  
297  std::string Sock::RecvUntilTerminator(uint8_t terminator,
298                                        std::chrono::milliseconds timeout,
299                                        CThreadInterrupt& interrupt,
300                                        size_t max_data) const
301  {
302      const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
303      std::string data;
304      bool terminator_found{false};
305  
306      // We must not consume any bytes past the terminator from the socket.
307      // One option is to read one byte at a time and check if we have read a terminator.
308      // However that is very slow. Instead, we peek at what is in the socket and only read
309      // as many bytes as possible without crossing the terminator.
310      // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
311      // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
312      // at a time is about 50 times slower.
313  
314      for (;;) {
315          if (data.size() >= max_data) {
316              throw std::runtime_error(
317                  strprintf("Received too many bytes without a terminator (%u)", data.size()));
318          }
319  
320          char buf[512];
321  
322          const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
323  
324          switch (peek_ret) {
325          case -1: {
326              const int err{WSAGetLastError()};
327              if (IOErrorIsPermanent(err)) {
328                  throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
329              }
330              break;
331          }
332          case 0:
333              throw std::runtime_error("Connection unexpectedly closed by peer");
334          default:
335              auto end = buf + peek_ret;
336              auto terminator_pos = std::find(buf, end, terminator);
337              terminator_found = terminator_pos != end;
338  
339              const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
340                                                      static_cast<size_t>(peek_ret)};
341  
342              const ssize_t read_ret{Recv(buf, try_len, 0)};
343  
344              if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
345                  throw std::runtime_error(
346                      strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
347                                "peek claimed %u bytes are available",
348                                read_ret, try_len, peek_ret));
349              }
350  
351              // Don't include the terminator in the output.
352              const size_t append_len{terminator_found ? try_len - 1 : try_len};
353  
354              data.append(buf, buf + append_len);
355  
356              if (terminator_found) {
357                  return data;
358              }
359          }
360  
361          const auto now = GetTime<std::chrono::milliseconds>();
362  
363          if (now >= deadline) {
364              throw std::runtime_error(strprintf(
365                  "Receive timeout (received %u bytes without terminator before that)", data.size()));
366          }
367  
368          if (interrupt) {
369              throw std::runtime_error(strprintf(
370                  "Receive interrupted (received %u bytes without terminator before that)",
371                  data.size()));
372          }
373  
374          // Wait for a short while (or the socket to become ready for reading) before retrying.
375          const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
376          (void)Wait(wait_time, RECV);
377      }
378  }
379  
380  bool Sock::IsConnected(std::string& errmsg) const
381  {
382      if (m_socket == INVALID_SOCKET) {
383          errmsg = "not connected";
384          return false;
385      }
386  
387      char c;
388      switch (Recv(&c, sizeof(c), MSG_PEEK)) {
389      case -1: {
390          const int err = WSAGetLastError();
391          if (IOErrorIsPermanent(err)) {
392              errmsg = NetworkErrorString(err);
393              return false;
394          }
395          return true;
396      }
397      case 0:
398          errmsg = "closed";
399          return false;
400      default:
401          return true;
402      }
403  }
404  
405  void Sock::Close()
406  {
407      if (m_socket == INVALID_SOCKET) {
408          return;
409      }
410  #ifdef WIN32
411      int ret = closesocket(m_socket);
412  #else
413      int ret = close(m_socket);
414  #endif
415      if (ret) {
416          LogWarning("Error closing socket %d: %s", m_socket, NetworkErrorString(WSAGetLastError()));
417      }
418      m_socket = INVALID_SOCKET;
419  }
420  
421  bool Sock::operator==(SOCKET s) const
422  {
423      return m_socket == s;
424  };
425  
426  std::string NetworkErrorString(int err)
427  {
428  #if defined(WIN32)
429      return Win32ErrorString(err);
430  #else
431      // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
432      return SysErrorString(err);
433  #endif
434  }