sock.cpp
1 // Copyright (c) 2020-present The Bitcoin Core developers 2 // Distributed under the MIT software license, see the accompanying 3 // file COPYING or http://www.opensource.org/licenses/mit-license.php. 4 5 #include <util/sock.h> 6 7 #include <compat/compat.h> 8 #include <span.h> 9 #include <tinyformat.h> 10 #include <util/check.h> 11 #include <util/log.h> 12 #include <util/syserror.h> 13 #include <util/threadinterrupt.h> 14 #include <util/time.h> 15 16 #include <algorithm> 17 #include <compare> 18 #include <exception> 19 #include <memory> 20 #include <stdexcept> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #ifdef USE_POLL 26 #include <poll.h> 27 #endif 28 29 static inline bool IOErrorIsPermanent(int err) 30 { 31 return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS; 32 } 33 34 Sock::Sock(SOCKET s) : m_socket(s) {} 35 36 Sock::Sock(Sock&& other) 37 { 38 m_socket = other.m_socket; 39 other.m_socket = INVALID_SOCKET; 40 } 41 42 Sock::~Sock() { Close(); } 43 44 Sock& Sock::operator=(Sock&& other) 45 { 46 Close(); 47 m_socket = other.m_socket; 48 other.m_socket = INVALID_SOCKET; 49 return *this; 50 } 51 52 ssize_t Sock::Send(const void* data, size_t len, int flags) const 53 { 54 return send(m_socket, static_cast<const char*>(data), len, flags); 55 } 56 57 ssize_t Sock::Recv(void* buf, size_t len, int flags) const 58 { 59 return recv(m_socket, static_cast<char*>(buf), len, flags); 60 } 61 62 int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const 63 { 64 return connect(m_socket, addr, addr_len); 65 } 66 67 int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const 68 { 69 return bind(m_socket, addr, addr_len); 70 } 71 72 int Sock::Listen(int backlog) const 73 { 74 return listen(m_socket, backlog); 75 } 76 77 std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const 78 { 79 #ifdef WIN32 80 static constexpr auto ERR = INVALID_SOCKET; 81 #else 82 static constexpr auto ERR = SOCKET_ERROR; 83 #endif 84 85 std::unique_ptr<Sock> sock; 86 87 const auto socket = accept(m_socket, addr, addr_len); 88 if (socket != ERR) { 89 try { 90 sock = std::make_unique<Sock>(socket); 91 } catch (const std::exception&) { 92 #ifdef WIN32 93 closesocket(socket); 94 #else 95 close(socket); 96 #endif 97 } 98 } 99 100 return sock; 101 } 102 103 int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const 104 { 105 return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len); 106 } 107 108 int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const 109 { 110 return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len); 111 } 112 113 int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const 114 { 115 return getsockname(m_socket, name, name_len); 116 } 117 118 bool Sock::SetNonBlocking() const 119 { 120 #ifdef WIN32 121 u_long on{1}; 122 if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) { 123 return false; 124 } 125 #else 126 const int flags{fcntl(m_socket, F_GETFL, 0)}; 127 if (flags == SOCKET_ERROR) { 128 return false; 129 } 130 if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) { 131 return false; 132 } 133 #endif 134 return true; 135 } 136 137 bool Sock::IsSelectable() const 138 { 139 #if defined(USE_POLL) || defined(WIN32) 140 return true; 141 #else 142 return m_socket < FD_SETSIZE; 143 #endif 144 } 145 146 bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const 147 { 148 // We need a `shared_ptr` holding `this` for `WaitMany()`, but don't want 149 // `this` to be destroyed when the `shared_ptr` goes out of scope at the 150 // end of this function. 151 // Create it with an aliasing shared_ptr that points to `this` without 152 // owning it. 153 std::shared_ptr<const Sock> shared{std::shared_ptr<const Sock>{}, this}; 154 155 EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})}; 156 157 if (!WaitMany(timeout, events_per_sock)) { 158 return false; 159 } 160 161 if (occurred != nullptr) { 162 *occurred = events_per_sock.begin()->second.occurred; 163 } 164 165 return true; 166 } 167 168 bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const 169 { 170 #ifdef USE_POLL 171 std::vector<pollfd> pfds; 172 for (const auto& [sock, events] : events_per_sock) { 173 pfds.emplace_back(); 174 auto& pfd = pfds.back(); 175 pfd.fd = sock->m_socket; 176 if (events.requested & RECV) { 177 pfd.events |= POLLIN; 178 } 179 if (events.requested & SEND) { 180 pfd.events |= POLLOUT; 181 } 182 } 183 184 if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) { 185 return false; 186 } 187 188 assert(pfds.size() == events_per_sock.size()); 189 size_t i{0}; 190 for (auto& [sock, events] : events_per_sock) { 191 assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd)); 192 events.occurred = 0; 193 if (pfds[i].revents & POLLIN) { 194 events.occurred |= RECV; 195 } 196 if (pfds[i].revents & POLLOUT) { 197 events.occurred |= SEND; 198 } 199 if (pfds[i].revents & (POLLERR | POLLHUP)) { 200 events.occurred |= ERR; 201 } 202 ++i; 203 } 204 205 return true; 206 #else 207 fd_set recv; 208 fd_set send; 209 fd_set err; 210 FD_ZERO(&recv); 211 FD_ZERO(&send); 212 FD_ZERO(&err); 213 SOCKET socket_max{0}; 214 215 for (const auto& [sock, events] : events_per_sock) { 216 if (!sock->IsSelectable()) { 217 return false; 218 } 219 const auto& s = sock->m_socket; 220 if (events.requested & RECV) { 221 FD_SET(s, &recv); 222 } 223 if (events.requested & SEND) { 224 FD_SET(s, &send); 225 } 226 FD_SET(s, &err); 227 socket_max = std::max(socket_max, s); 228 } 229 230 timeval tv = MillisToTimeval(timeout); 231 232 if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) { 233 return false; 234 } 235 236 for (auto& [sock, events] : events_per_sock) { 237 const auto& s = sock->m_socket; 238 events.occurred = 0; 239 if (FD_ISSET(s, &recv)) { 240 events.occurred |= RECV; 241 } 242 if (FD_ISSET(s, &send)) { 243 events.occurred |= SEND; 244 } 245 if (FD_ISSET(s, &err)) { 246 events.occurred |= ERR; 247 } 248 } 249 250 return true; 251 #endif /* USE_POLL */ 252 } 253 254 void Sock::SendComplete(std::span<const unsigned char> data, 255 std::chrono::milliseconds timeout, 256 CThreadInterrupt& interrupt) const 257 { 258 const auto deadline = GetTime<std::chrono::milliseconds>() + timeout; 259 size_t sent{0}; 260 261 for (;;) { 262 const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)}; 263 264 if (ret > 0) { 265 sent += static_cast<size_t>(ret); 266 if (sent == data.size()) { 267 break; 268 } 269 } else { 270 const int err{WSAGetLastError()}; 271 if (IOErrorIsPermanent(err)) { 272 throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err))); 273 } 274 } 275 276 const auto now = GetTime<std::chrono::milliseconds>(); 277 278 if (now >= deadline) { 279 throw std::runtime_error(strprintf( 280 "Send timeout (sent only %u of %u bytes before that)", sent, data.size())); 281 } 282 283 if (interrupt) { 284 throw std::runtime_error(strprintf( 285 "Send interrupted (sent only %u of %u bytes before that)", sent, data.size())); 286 } 287 288 // Wait for a short while (or the socket to become ready for sending) before retrying 289 // if nothing was sent. 290 const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO}); 291 (void)Wait(wait_time, SEND); 292 } 293 } 294 295 void Sock::SendComplete(std::span<const char> data, 296 std::chrono::milliseconds timeout, 297 CThreadInterrupt& interrupt) const 298 { 299 SendComplete(MakeUCharSpan(data), timeout, interrupt); 300 } 301 302 std::string Sock::RecvUntilTerminator(uint8_t terminator, 303 std::chrono::milliseconds timeout, 304 CThreadInterrupt& interrupt, 305 size_t max_data) const 306 { 307 const auto deadline = GetTime<std::chrono::milliseconds>() + timeout; 308 std::string data; 309 bool terminator_found{false}; 310 311 // We must not consume any bytes past the terminator from the socket. 312 // One option is to read one byte at a time and check if we have read a terminator. 313 // However that is very slow. Instead, we peek at what is in the socket and only read 314 // as many bytes as possible without crossing the terminator. 315 // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read 316 // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte 317 // at a time is about 50 times slower. 318 319 for (;;) { 320 if (data.size() >= max_data) { 321 throw std::runtime_error( 322 strprintf("Received too many bytes without a terminator (%u)", data.size())); 323 } 324 325 char buf[512]; 326 327 const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)}; 328 329 switch (peek_ret) { 330 case -1: { 331 const int err{WSAGetLastError()}; 332 if (IOErrorIsPermanent(err)) { 333 throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err))); 334 } 335 break; 336 } 337 case 0: 338 throw std::runtime_error("Connection unexpectedly closed by peer"); 339 default: 340 auto end = buf + peek_ret; 341 auto terminator_pos = std::find(buf, end, terminator); 342 terminator_found = terminator_pos != end; 343 344 const size_t try_len{terminator_found ? terminator_pos - buf + 1 : 345 static_cast<size_t>(peek_ret)}; 346 347 const ssize_t read_ret{Recv(buf, try_len, 0)}; 348 349 if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) { 350 throw std::runtime_error( 351 strprintf("recv() returned %u bytes on attempt to read %u bytes but previous " 352 "peek claimed %u bytes are available", 353 read_ret, try_len, peek_ret)); 354 } 355 356 // Don't include the terminator in the output. 357 const size_t append_len{terminator_found ? try_len - 1 : try_len}; 358 359 data.append(buf, buf + append_len); 360 361 if (terminator_found) { 362 return data; 363 } 364 } 365 366 const auto now = GetTime<std::chrono::milliseconds>(); 367 368 if (now >= deadline) { 369 throw std::runtime_error(strprintf( 370 "Receive timeout (received %u bytes without terminator before that)", data.size())); 371 } 372 373 if (interrupt) { 374 throw std::runtime_error(strprintf( 375 "Receive interrupted (received %u bytes without terminator before that)", 376 data.size())); 377 } 378 379 // Wait for a short while (or the socket to become ready for reading) before retrying. 380 const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO}); 381 (void)Wait(wait_time, RECV); 382 } 383 } 384 385 bool Sock::IsConnected(std::string& errmsg) const 386 { 387 if (m_socket == INVALID_SOCKET) { 388 errmsg = "not connected"; 389 return false; 390 } 391 392 char c; 393 switch (Recv(&c, sizeof(c), MSG_PEEK)) { 394 case -1: { 395 const int err = WSAGetLastError(); 396 if (IOErrorIsPermanent(err)) { 397 errmsg = NetworkErrorString(err); 398 return false; 399 } 400 return true; 401 } 402 case 0: 403 errmsg = "closed"; 404 return false; 405 default: 406 return true; 407 } 408 } 409 410 void Sock::Close() 411 { 412 if (m_socket == INVALID_SOCKET) { 413 return; 414 } 415 #ifdef WIN32 416 int ret = closesocket(m_socket); 417 #else 418 int ret = close(m_socket); 419 #endif 420 if (ret) { 421 LogWarning("Error closing socket %d: %s", m_socket, NetworkErrorString(WSAGetLastError())); 422 } 423 m_socket = INVALID_SOCKET; 424 } 425 426 bool Sock::operator==(SOCKET s) const 427 { 428 return m_socket == s; 429 }; 430 431 std::string NetworkErrorString(int err) 432 { 433 #if defined(WIN32) 434 return Win32ErrorString(err); 435 #else 436 // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString. 437 return SysErrorString(err); 438 #endif 439 }