// sock.cpp
1 // Copyright (c) 2020-2022 The Bitcoin Core developers 2 // Distributed under the MIT software license, see the accompanying 3 // file COPYING or http://www.opensource.org/licenses/mit-license.php. 4 5 #include <common/system.h> 6 #include <compat/compat.h> 7 #include <logging.h> 8 #include <tinyformat.h> 9 #include <util/sock.h> 10 #include <util/syserror.h> 11 #include <util/threadinterrupt.h> 12 #include <util/time.h> 13 14 #include <memory> 15 #include <stdexcept> 16 #include <string> 17 18 #ifdef USE_POLL 19 #include <poll.h> 20 #endif 21 22 static inline bool IOErrorIsPermanent(int err) 23 { 24 return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS; 25 } 26 27 Sock::Sock(SOCKET s) : m_socket(s) {} 28 29 Sock::Sock(Sock&& other) 30 { 31 m_socket = other.m_socket; 32 other.m_socket = INVALID_SOCKET; 33 } 34 35 Sock::~Sock() { Close(); } 36 37 Sock& Sock::operator=(Sock&& other) 38 { 39 Close(); 40 m_socket = other.m_socket; 41 other.m_socket = INVALID_SOCKET; 42 return *this; 43 } 44 45 ssize_t Sock::Send(const void* data, size_t len, int flags) const 46 { 47 return send(m_socket, static_cast<const char*>(data), len, flags); 48 } 49 50 ssize_t Sock::Recv(void* buf, size_t len, int flags) const 51 { 52 return recv(m_socket, static_cast<char*>(buf), len, flags); 53 } 54 55 int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const 56 { 57 return connect(m_socket, addr, addr_len); 58 } 59 60 int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const 61 { 62 return bind(m_socket, addr, addr_len); 63 } 64 65 int Sock::Listen(int backlog) const 66 { 67 return listen(m_socket, backlog); 68 } 69 70 std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const 71 { 72 #ifdef WIN32 73 static constexpr auto ERR = INVALID_SOCKET; 74 #else 75 static constexpr auto ERR = SOCKET_ERROR; 76 #endif 77 78 std::unique_ptr<Sock> sock; 79 80 const auto socket = accept(m_socket, addr, addr_len); 81 if (socket != ERR) 
{ 82 try { 83 sock = std::make_unique<Sock>(socket); 84 } catch (const std::exception&) { 85 #ifdef WIN32 86 closesocket(socket); 87 #else 88 close(socket); 89 #endif 90 } 91 } 92 93 return sock; 94 } 95 96 int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const 97 { 98 return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len); 99 } 100 101 int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const 102 { 103 return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len); 104 } 105 106 int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const 107 { 108 return getsockname(m_socket, name, name_len); 109 } 110 111 bool Sock::SetNonBlocking() const 112 { 113 #ifdef WIN32 114 u_long on{1}; 115 if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) { 116 return false; 117 } 118 #else 119 const int flags{fcntl(m_socket, F_GETFL, 0)}; 120 if (flags == SOCKET_ERROR) { 121 return false; 122 } 123 if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) { 124 return false; 125 } 126 #endif 127 return true; 128 } 129 130 bool Sock::IsSelectable() const 131 { 132 #if defined(USE_POLL) || defined(WIN32) 133 return true; 134 #else 135 return m_socket < FD_SETSIZE; 136 #endif 137 } 138 139 bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const 140 { 141 // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want 142 // `this` to be destroyed when the `shared_ptr` goes out of scope at the 143 // end of this function. Create it with a custom noop deleter. 
144 std::shared_ptr<const Sock> shared{this, [](const Sock*) {}}; 145 146 EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})}; 147 148 if (!WaitMany(timeout, events_per_sock)) { 149 return false; 150 } 151 152 if (occurred != nullptr) { 153 *occurred = events_per_sock.begin()->second.occurred; 154 } 155 156 return true; 157 } 158 159 bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const 160 { 161 #ifdef USE_POLL 162 std::vector<pollfd> pfds; 163 for (const auto& [sock, events] : events_per_sock) { 164 pfds.emplace_back(); 165 auto& pfd = pfds.back(); 166 pfd.fd = sock->m_socket; 167 if (events.requested & RECV) { 168 pfd.events |= POLLIN; 169 } 170 if (events.requested & SEND) { 171 pfd.events |= POLLOUT; 172 } 173 } 174 175 if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) { 176 return false; 177 } 178 179 assert(pfds.size() == events_per_sock.size()); 180 size_t i{0}; 181 for (auto& [sock, events] : events_per_sock) { 182 assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd)); 183 events.occurred = 0; 184 if (pfds[i].revents & POLLIN) { 185 events.occurred |= RECV; 186 } 187 if (pfds[i].revents & POLLOUT) { 188 events.occurred |= SEND; 189 } 190 if (pfds[i].revents & (POLLERR | POLLHUP)) { 191 events.occurred |= ERR; 192 } 193 ++i; 194 } 195 196 return true; 197 #else 198 fd_set recv; 199 fd_set send; 200 fd_set err; 201 FD_ZERO(&recv); 202 FD_ZERO(&send); 203 FD_ZERO(&err); 204 SOCKET socket_max{0}; 205 206 for (const auto& [sock, events] : events_per_sock) { 207 if (!sock->IsSelectable()) { 208 return false; 209 } 210 const auto& s = sock->m_socket; 211 if (events.requested & RECV) { 212 FD_SET(s, &recv); 213 } 214 if (events.requested & SEND) { 215 FD_SET(s, &send); 216 } 217 FD_SET(s, &err); 218 socket_max = std::max(socket_max, s); 219 } 220 221 timeval tv = MillisToTimeval(timeout); 222 223 if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) { 224 
return false; 225 } 226 227 for (auto& [sock, events] : events_per_sock) { 228 const auto& s = sock->m_socket; 229 events.occurred = 0; 230 if (FD_ISSET(s, &recv)) { 231 events.occurred |= RECV; 232 } 233 if (FD_ISSET(s, &send)) { 234 events.occurred |= SEND; 235 } 236 if (FD_ISSET(s, &err)) { 237 events.occurred |= ERR; 238 } 239 } 240 241 return true; 242 #endif /* USE_POLL */ 243 } 244 245 void Sock::SendComplete(Span<const unsigned char> data, 246 std::chrono::milliseconds timeout, 247 CThreadInterrupt& interrupt) const 248 { 249 const auto deadline = GetTime<std::chrono::milliseconds>() + timeout; 250 size_t sent{0}; 251 252 for (;;) { 253 const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)}; 254 255 if (ret > 0) { 256 sent += static_cast<size_t>(ret); 257 if (sent == data.size()) { 258 break; 259 } 260 } else { 261 const int err{WSAGetLastError()}; 262 if (IOErrorIsPermanent(err)) { 263 throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err))); 264 } 265 } 266 267 const auto now = GetTime<std::chrono::milliseconds>(); 268 269 if (now >= deadline) { 270 throw std::runtime_error(strprintf( 271 "Send timeout (sent only %u of %u bytes before that)", sent, data.size())); 272 } 273 274 if (interrupt) { 275 throw std::runtime_error(strprintf( 276 "Send interrupted (sent only %u of %u bytes before that)", sent, data.size())); 277 } 278 279 // Wait for a short while (or the socket to become ready for sending) before retrying 280 // if nothing was sent. 
281 const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO}); 282 (void)Wait(wait_time, SEND); 283 } 284 } 285 286 void Sock::SendComplete(Span<const char> data, 287 std::chrono::milliseconds timeout, 288 CThreadInterrupt& interrupt) const 289 { 290 SendComplete(MakeUCharSpan(data), timeout, interrupt); 291 } 292 293 std::string Sock::RecvUntilTerminator(uint8_t terminator, 294 std::chrono::milliseconds timeout, 295 CThreadInterrupt& interrupt, 296 size_t max_data) const 297 { 298 const auto deadline = GetTime<std::chrono::milliseconds>() + timeout; 299 std::string data; 300 bool terminator_found{false}; 301 302 // We must not consume any bytes past the terminator from the socket. 303 // One option is to read one byte at a time and check if we have read a terminator. 304 // However that is very slow. Instead, we peek at what is in the socket and only read 305 // as many bytes as possible without crossing the terminator. 306 // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read 307 // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte 308 // at a time is about 50 times slower. 309 310 for (;;) { 311 if (data.size() >= max_data) { 312 throw std::runtime_error( 313 strprintf("Received too many bytes without a terminator (%u)", data.size())); 314 } 315 316 char buf[512]; 317 318 const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)}; 319 320 switch (peek_ret) { 321 case -1: { 322 const int err{WSAGetLastError()}; 323 if (IOErrorIsPermanent(err)) { 324 throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err))); 325 } 326 break; 327 } 328 case 0: 329 throw std::runtime_error("Connection unexpectedly closed by peer"); 330 default: 331 auto end = buf + peek_ret; 332 auto terminator_pos = std::find(buf, end, terminator); 333 terminator_found = terminator_pos != end; 334 335 const size_t try_len{terminator_found ? 
terminator_pos - buf + 1 : 336 static_cast<size_t>(peek_ret)}; 337 338 const ssize_t read_ret{Recv(buf, try_len, 0)}; 339 340 if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) { 341 throw std::runtime_error( 342 strprintf("recv() returned %u bytes on attempt to read %u bytes but previous " 343 "peek claimed %u bytes are available", 344 read_ret, try_len, peek_ret)); 345 } 346 347 // Don't include the terminator in the output. 348 const size_t append_len{terminator_found ? try_len - 1 : try_len}; 349 350 data.append(buf, buf + append_len); 351 352 if (terminator_found) { 353 return data; 354 } 355 } 356 357 const auto now = GetTime<std::chrono::milliseconds>(); 358 359 if (now >= deadline) { 360 throw std::runtime_error(strprintf( 361 "Receive timeout (received %u bytes without terminator before that)", data.size())); 362 } 363 364 if (interrupt) { 365 throw std::runtime_error(strprintf( 366 "Receive interrupted (received %u bytes without terminator before that)", 367 data.size())); 368 } 369 370 // Wait for a short while (or the socket to become ready for reading) before retrying. 
371 const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO}); 372 (void)Wait(wait_time, RECV); 373 } 374 } 375 376 bool Sock::IsConnected(std::string& errmsg) const 377 { 378 if (m_socket == INVALID_SOCKET) { 379 errmsg = "not connected"; 380 return false; 381 } 382 383 char c; 384 switch (Recv(&c, sizeof(c), MSG_PEEK)) { 385 case -1: { 386 const int err = WSAGetLastError(); 387 if (IOErrorIsPermanent(err)) { 388 errmsg = NetworkErrorString(err); 389 return false; 390 } 391 return true; 392 } 393 case 0: 394 errmsg = "closed"; 395 return false; 396 default: 397 return true; 398 } 399 } 400 401 void Sock::Close() 402 { 403 if (m_socket == INVALID_SOCKET) { 404 return; 405 } 406 #ifdef WIN32 407 int ret = closesocket(m_socket); 408 #else 409 int ret = close(m_socket); 410 #endif 411 if (ret) { 412 LogPrintf("Error closing socket %d: %s\n", m_socket, NetworkErrorString(WSAGetLastError())); 413 } 414 m_socket = INVALID_SOCKET; 415 } 416 417 bool Sock::operator==(SOCKET s) const 418 { 419 return m_socket == s; 420 }; 421 422 std::string NetworkErrorString(int err) 423 { 424 #if defined(WIN32) 425 return Win32ErrorString(err); 426 #else 427 // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString. 428 return SysErrorString(err); 429 #endif 430 }