sock.cpp
1 // Copyright (c) 2020-present The Bitcoin Core developers 2 // Distributed under the MIT software license, see the accompanying 3 // file COPYING or http://www.opensource.org/licenses/mit-license.php. 4 5 #include <util/sock.h> 6 7 #include <common/system.h> 8 #include <compat/compat.h> 9 #include <span.h> 10 #include <tinyformat.h> 11 #include <util/log.h> 12 #include <util/syserror.h> 13 #include <util/threadinterrupt.h> 14 #include <util/time.h> 15 16 #include <memory> 17 #include <stdexcept> 18 #include <string> 19 20 #ifdef USE_POLL 21 #include <poll.h> 22 #endif 23 24 static inline bool IOErrorIsPermanent(int err) 25 { 26 return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS; 27 } 28 29 Sock::Sock(SOCKET s) : m_socket(s) {} 30 31 Sock::Sock(Sock&& other) 32 { 33 m_socket = other.m_socket; 34 other.m_socket = INVALID_SOCKET; 35 } 36 37 Sock::~Sock() { Close(); } 38 39 Sock& Sock::operator=(Sock&& other) 40 { 41 Close(); 42 m_socket = other.m_socket; 43 other.m_socket = INVALID_SOCKET; 44 return *this; 45 } 46 47 ssize_t Sock::Send(const void* data, size_t len, int flags) const 48 { 49 return send(m_socket, static_cast<const char*>(data), len, flags); 50 } 51 52 ssize_t Sock::Recv(void* buf, size_t len, int flags) const 53 { 54 return recv(m_socket, static_cast<char*>(buf), len, flags); 55 } 56 57 int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const 58 { 59 return connect(m_socket, addr, addr_len); 60 } 61 62 int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const 63 { 64 return bind(m_socket, addr, addr_len); 65 } 66 67 int Sock::Listen(int backlog) const 68 { 69 return listen(m_socket, backlog); 70 } 71 72 std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const 73 { 74 #ifdef WIN32 75 static constexpr auto ERR = INVALID_SOCKET; 76 #else 77 static constexpr auto ERR = SOCKET_ERROR; 78 #endif 79 80 std::unique_ptr<Sock> sock; 81 82 const auto socket = accept(m_socket, addr, addr_len); 83 if (socket != ERR) { 84 try { 85 sock = std::make_unique<Sock>(socket); 86 } catch (const std::exception&) { 87 #ifdef WIN32 88 closesocket(socket); 89 #else 90 close(socket); 91 #endif 92 } 93 } 94 95 return sock; 96 } 97 98 int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const 99 { 100 return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len); 101 } 102 103 int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const 104 { 105 return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len); 106 } 107 108 int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const 109 { 110 return getsockname(m_socket, name, name_len); 111 } 112 113 bool Sock::SetNonBlocking() const 114 { 115 #ifdef WIN32 116 u_long on{1}; 117 if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) { 118 return false; 119 } 120 #else 121 const int flags{fcntl(m_socket, F_GETFL, 0)}; 122 if (flags == SOCKET_ERROR) { 123 return false; 124 } 125 if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) { 126 return false; 127 } 128 #endif 129 return true; 130 } 131 132 bool Sock::IsSelectable() const 133 { 134 #if defined(USE_POLL) || defined(WIN32) 135 return true; 136 #else 137 return m_socket < FD_SETSIZE; 138 #endif 139 } 140 141 bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const 142 { 143 // We need a `shared_ptr` holding `this` for `WaitMany()`, but don't want 144 // `this` to be destroyed when the `shared_ptr` goes out of scope at the 145 // end of this function. 146 // Create it with an aliasing shared_ptr that points to `this` without 147 // owning it. 148 std::shared_ptr<const Sock> shared{std::shared_ptr<const Sock>{}, this}; 149 150 EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})}; 151 152 if (!WaitMany(timeout, events_per_sock)) { 153 return false; 154 } 155 156 if (occurred != nullptr) { 157 *occurred = events_per_sock.begin()->second.occurred; 158 } 159 160 return true; 161 } 162 163 bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const 164 { 165 #ifdef USE_POLL 166 std::vector<pollfd> pfds; 167 for (const auto& [sock, events] : events_per_sock) { 168 pfds.emplace_back(); 169 auto& pfd = pfds.back(); 170 pfd.fd = sock->m_socket; 171 if (events.requested & RECV) { 172 pfd.events |= POLLIN; 173 } 174 if (events.requested & SEND) { 175 pfd.events |= POLLOUT; 176 } 177 } 178 179 if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) { 180 return false; 181 } 182 183 assert(pfds.size() == events_per_sock.size()); 184 size_t i{0}; 185 for (auto& [sock, events] : events_per_sock) { 186 assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd)); 187 events.occurred = 0; 188 if (pfds[i].revents & POLLIN) { 189 events.occurred |= RECV; 190 } 191 if (pfds[i].revents & POLLOUT) { 192 events.occurred |= SEND; 193 } 194 if (pfds[i].revents & (POLLERR | POLLHUP)) { 195 events.occurred |= ERR; 196 } 197 ++i; 198 } 199 200 return true; 201 #else 202 fd_set recv; 203 fd_set send; 204 fd_set err; 205 FD_ZERO(&recv); 206 FD_ZERO(&send); 207 FD_ZERO(&err); 208 SOCKET socket_max{0}; 209 210 for (const auto& [sock, events] : events_per_sock) { 211 if (!sock->IsSelectable()) { 212 return false; 213 } 214 const auto& s = sock->m_socket; 215 if (events.requested & RECV) { 216 FD_SET(s, &recv); 217 } 218 if (events.requested & SEND) { 219 FD_SET(s, &send); 220 } 221 FD_SET(s, &err); 222 socket_max = std::max(socket_max, s); 223 } 224 225 timeval tv = MillisToTimeval(timeout); 226 227 if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) { 228 return false; 229 } 230 231 for (auto& [sock, events] : events_per_sock) { 232 const auto& s = sock->m_socket; 233 events.occurred = 0; 234 if (FD_ISSET(s, &recv)) { 235 events.occurred |= RECV; 236 } 237 if (FD_ISSET(s, &send)) { 238 events.occurred |= SEND; 239 } 240 if (FD_ISSET(s, &err)) { 241 events.occurred |= ERR; 242 } 243 } 244 245 return true; 246 #endif /* USE_POLL */ 247 } 248 249 void Sock::SendComplete(std::span<const unsigned char> data, 250 std::chrono::milliseconds timeout, 251 CThreadInterrupt& interrupt) const 252 { 253 const auto deadline = GetTime<std::chrono::milliseconds>() + timeout; 254 size_t sent{0}; 255 256 for (;;) { 257 const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)}; 258 259 if (ret > 0) { 260 sent += static_cast<size_t>(ret); 261 if (sent == data.size()) { 262 break; 263 } 264 } else { 265 const int err{WSAGetLastError()}; 266 if (IOErrorIsPermanent(err)) { 267 throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err))); 268 } 269 } 270 271 const auto now = GetTime<std::chrono::milliseconds>(); 272 273 if (now >= deadline) { 274 throw std::runtime_error(strprintf( 275 "Send timeout (sent only %u of %u bytes before that)", sent, data.size())); 276 } 277 278 if (interrupt) { 279 throw std::runtime_error(strprintf( 280 "Send interrupted (sent only %u of %u bytes before that)", sent, data.size())); 281 } 282 283 // Wait for a short while (or the socket to become ready for sending) before retrying 284 // if nothing was sent. 285 const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO}); 286 (void)Wait(wait_time, SEND); 287 } 288 } 289 290 void Sock::SendComplete(std::span<const char> data, 291 std::chrono::milliseconds timeout, 292 CThreadInterrupt& interrupt) const 293 { 294 SendComplete(MakeUCharSpan(data), timeout, interrupt); 295 } 296 297 std::string Sock::RecvUntilTerminator(uint8_t terminator, 298 std::chrono::milliseconds timeout, 299 CThreadInterrupt& interrupt, 300 size_t max_data) const 301 { 302 const auto deadline = GetTime<std::chrono::milliseconds>() + timeout; 303 std::string data; 304 bool terminator_found{false}; 305 306 // We must not consume any bytes past the terminator from the socket. 307 // One option is to read one byte at a time and check if we have read a terminator. 308 // However that is very slow. Instead, we peek at what is in the socket and only read 309 // as many bytes as possible without crossing the terminator. 310 // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read 311 // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte 312 // at a time is about 50 times slower. 313 314 for (;;) { 315 if (data.size() >= max_data) { 316 throw std::runtime_error( 317 strprintf("Received too many bytes without a terminator (%u)", data.size())); 318 } 319 320 char buf[512]; 321 322 const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)}; 323 324 switch (peek_ret) { 325 case -1: { 326 const int err{WSAGetLastError()}; 327 if (IOErrorIsPermanent(err)) { 328 throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err))); 329 } 330 break; 331 } 332 case 0: 333 throw std::runtime_error("Connection unexpectedly closed by peer"); 334 default: 335 auto end = buf + peek_ret; 336 auto terminator_pos = std::find(buf, end, terminator); 337 terminator_found = terminator_pos != end; 338 339 const size_t try_len{terminator_found ? terminator_pos - buf + 1 : 340 static_cast<size_t>(peek_ret)}; 341 342 const ssize_t read_ret{Recv(buf, try_len, 0)}; 343 344 if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) { 345 throw std::runtime_error( 346 strprintf("recv() returned %u bytes on attempt to read %u bytes but previous " 347 "peek claimed %u bytes are available", 348 read_ret, try_len, peek_ret)); 349 } 350 351 // Don't include the terminator in the output. 352 const size_t append_len{terminator_found ? try_len - 1 : try_len}; 353 354 data.append(buf, buf + append_len); 355 356 if (terminator_found) { 357 return data; 358 } 359 } 360 361 const auto now = GetTime<std::chrono::milliseconds>(); 362 363 if (now >= deadline) { 364 throw std::runtime_error(strprintf( 365 "Receive timeout (received %u bytes without terminator before that)", data.size())); 366 } 367 368 if (interrupt) { 369 throw std::runtime_error(strprintf( 370 "Receive interrupted (received %u bytes without terminator before that)", 371 data.size())); 372 } 373 374 // Wait for a short while (or the socket to become ready for reading) before retrying. 375 const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO}); 376 (void)Wait(wait_time, RECV); 377 } 378 } 379 380 bool Sock::IsConnected(std::string& errmsg) const 381 { 382 if (m_socket == INVALID_SOCKET) { 383 errmsg = "not connected"; 384 return false; 385 } 386 387 char c; 388 switch (Recv(&c, sizeof(c), MSG_PEEK)) { 389 case -1: { 390 const int err = WSAGetLastError(); 391 if (IOErrorIsPermanent(err)) { 392 errmsg = NetworkErrorString(err); 393 return false; 394 } 395 return true; 396 } 397 case 0: 398 errmsg = "closed"; 399 return false; 400 default: 401 return true; 402 } 403 } 404 405 void Sock::Close() 406 { 407 if (m_socket == INVALID_SOCKET) { 408 return; 409 } 410 #ifdef WIN32 411 int ret = closesocket(m_socket); 412 #else 413 int ret = close(m_socket); 414 #endif 415 if (ret) { 416 LogWarning("Error closing socket %d: %s", m_socket, NetworkErrorString(WSAGetLastError())); 417 } 418 m_socket = INVALID_SOCKET; 419 } 420 421 bool Sock::operator==(SOCKET s) const 422 { 423 return m_socket == s; 424 }; 425 426 std::string NetworkErrorString(int err) 427 { 428 #if defined(WIN32) 429 return Win32ErrorString(err); 430 #else 431 // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString. 432 return SysErrorString(err); 433 #endif 434 }