// ggml-rpc.cpp
1 #include "ggml-rpc.h" 2 #include "ggml.h" 3 #include "ggml-backend-impl.h" 4 5 #include <cinttypes> 6 #include <string> 7 #include <vector> 8 #include <memory> 9 #include <mutex> 10 #include <unordered_map> 11 #include <unordered_set> 12 #ifdef _WIN32 13 # define WIN32_LEAN_AND_MEAN 14 # ifndef NOMINMAX 15 # define NOMINMAX 16 # endif 17 # include <windows.h> 18 # include <winsock2.h> 19 #else 20 # include <arpa/inet.h> 21 # include <sys/socket.h> 22 # include <sys/types.h> 23 # include <netinet/in.h> 24 # include <netinet/tcp.h> 25 # include <netdb.h> 26 # include <unistd.h> 27 #endif 28 #include <string.h> 29 30 #define UNUSED GGML_UNUSED 31 32 #define GGML_DEBUG 0 33 #if (GGML_DEBUG >= 1) 34 #define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) 35 #else 36 #define GGML_PRINT_DEBUG(...) 37 #endif 38 39 #ifdef _WIN32 40 typedef SOCKET sockfd_t; 41 using ssize_t = __int64; 42 #else 43 typedef int sockfd_t; 44 #endif 45 46 // cross-platform socket 47 struct socket_t { 48 sockfd_t fd; 49 socket_t(sockfd_t fd) : fd(fd) {} 50 ~socket_t() { 51 GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd); 52 #ifdef _WIN32 53 closesocket(this->fd); 54 #else 55 close(this->fd); 56 #endif 57 } 58 }; 59 60 // ggml_tensor is serialized into rpc_tensor 61 #pragma pack(push, 1) 62 struct rpc_tensor { 63 uint64_t id; 64 uint32_t type; 65 uint64_t buffer; 66 uint32_t ne[GGML_MAX_DIMS]; 67 uint32_t nb[GGML_MAX_DIMS]; 68 uint32_t op; 69 int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; 70 int32_t flags; 71 uint64_t src[GGML_MAX_SRC]; 72 uint64_t view_src; 73 uint64_t view_offs; 74 uint64_t data; 75 char name[GGML_MAX_NAME]; 76 77 char padding[4]; 78 }; 79 #pragma pack(pop) 80 81 static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8"); 82 83 // RPC commands 84 enum rpc_cmd { 85 ALLOC_BUFFER = 0, 86 GET_ALIGNMENT, 87 GET_MAX_SIZE, 88 BUFFER_GET_BASE, 89 FREE_BUFFER, 90 BUFFER_CLEAR, 91 SET_TENSOR, 92 GET_TENSOR, 93 COPY_TENSOR, 94 GRAPH_COMPUTE, 
95 GET_DEVICE_MEMORY, 96 }; 97 98 // RPC data structures 99 100 static ggml_guid_t ggml_backend_rpc_guid() { 101 static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03}; 102 return &guid; 103 } 104 105 struct ggml_backend_rpc_buffer_type_context { 106 std::string endpoint; 107 std::string name; 108 size_t alignment; 109 size_t max_size; 110 }; 111 112 struct ggml_backend_rpc_context { 113 std::string endpoint; 114 std::string name; 115 }; 116 117 struct ggml_backend_rpc_buffer_context { 118 std::shared_ptr<socket_t> sock; 119 std::unordered_map<ggml_backend_buffer_t, void *> base_cache; 120 uint64_t remote_ptr; 121 std::string name; 122 }; 123 124 // RPC helper functions 125 126 static std::shared_ptr<socket_t> make_socket(sockfd_t fd) { 127 #ifdef _WIN32 128 if (fd == INVALID_SOCKET) { 129 return nullptr; 130 } 131 #else 132 if (fd < 0) { 133 return nullptr; 134 } 135 #endif 136 return std::make_shared<socket_t>(fd); 137 } 138 139 static bool set_no_delay(sockfd_t sockfd) { 140 int flag = 1; 141 // set TCP_NODELAY to disable Nagle's algorithm 142 int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); 143 return ret == 0; 144 } 145 146 static bool set_reuse_addr(sockfd_t sockfd) { 147 int flag = 1; 148 int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); 149 return ret == 0; 150 } 151 152 static std::shared_ptr<socket_t> socket_connect(const char * host, int port) { 153 struct sockaddr_in addr; 154 auto sockfd = socket(AF_INET, SOCK_STREAM, 0); 155 auto sock_ptr = make_socket(sockfd); 156 if (sock_ptr == nullptr) { 157 return nullptr; 158 } 159 if (!set_no_delay(sockfd)) { 160 fprintf(stderr, "Failed to set TCP_NODELAY\n"); 161 return nullptr; 162 } 163 addr.sin_family = AF_INET; 164 addr.sin_port = htons(port); 165 struct hostent * server = gethostbyname(host); 166 if (server == NULL) { 167 fprintf(stderr, "Cannot resolve host '%s'\n", 
host); 168 return nullptr; 169 } 170 memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); 171 if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { 172 return nullptr; 173 } 174 return sock_ptr; 175 } 176 177 static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) { 178 auto client_socket_fd = accept(srv_sockfd, NULL, NULL); 179 auto client_socket = make_socket(client_socket_fd); 180 if (client_socket == nullptr) { 181 return nullptr; 182 } 183 if (!set_no_delay(client_socket_fd)) { 184 fprintf(stderr, "Failed to set TCP_NODELAY\n"); 185 return nullptr; 186 } 187 return client_socket; 188 } 189 190 static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) { 191 auto sockfd = socket(AF_INET, SOCK_STREAM, 0); 192 auto sock = make_socket(sockfd); 193 if (sock == nullptr) { 194 return nullptr; 195 } 196 if (!set_reuse_addr(sockfd)) { 197 fprintf(stderr, "Failed to set SO_REUSEADDR\n"); 198 return nullptr; 199 } 200 struct sockaddr_in serv_addr; 201 serv_addr.sin_family = AF_INET; 202 serv_addr.sin_addr.s_addr = inet_addr(host); 203 serv_addr.sin_port = htons(port); 204 205 if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { 206 return nullptr; 207 } 208 if (listen(sockfd, 1) < 0) { 209 return nullptr; 210 } 211 return sock; 212 } 213 214 static bool send_data(sockfd_t sockfd, const void * data, size_t size) { 215 size_t bytes_sent = 0; 216 while (bytes_sent < size) { 217 ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0); 218 if (n < 0) { 219 return false; 220 } 221 bytes_sent += n; 222 } 223 return true; 224 } 225 226 static bool recv_data(sockfd_t sockfd, void * data, size_t size) { 227 size_t bytes_recv = 0; 228 while (bytes_recv < size) { 229 ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0); 230 if (n <= 0) { 231 return false; 232 } 233 bytes_recv += n; 234 } 235 return true; 236 } 237 238 static bool parse_endpoint(const 
std::string & endpoint, std::string & host, int & port) { 239 size_t pos = endpoint.find(':'); 240 if (pos == std::string::npos) { 241 return false; 242 } 243 host = endpoint.substr(0, pos); 244 port = std::stoi(endpoint.substr(pos + 1)); 245 return true; 246 } 247 248 // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | 249 // RPC response: | response_size (8 bytes) | response_data (response_size bytes) | 250 static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { 251 uint8_t cmd_byte = cmd; 252 if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { 253 return false; 254 } 255 uint64_t input_size = input.size(); 256 if (!send_data(sock->fd, &input_size, sizeof(input_size))) { 257 return false; 258 } 259 if (!send_data(sock->fd, input.data(), input.size())) { 260 return false; 261 } 262 uint64_t output_size; 263 if (!recv_data(sock->fd, &output_size, sizeof(output_size))) { 264 return false; 265 } 266 if (output_size == 0) { 267 output.clear(); 268 return true; 269 } 270 output.resize(output_size); 271 if (!recv_data(sock->fd, output.data(), output_size)) { 272 return false; 273 } 274 return true; 275 } 276 277 // RPC client-side implementation 278 279 static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) { 280 static std::mutex mutex; 281 std::lock_guard<std::mutex> lock(mutex); 282 static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets; 283 static bool initialized = false; 284 285 auto it = sockets.find(endpoint); 286 if (it != sockets.end()) { 287 if (auto sock = it->second.lock()) { 288 return sock; 289 } 290 } 291 std::string host; 292 int port; 293 if (!parse_endpoint(endpoint, host, port)) { 294 return nullptr; 295 } 296 #ifdef _WIN32 297 if (!initialized) { 298 WSADATA wsaData; 299 int res = WSAStartup(MAKEWORD(2, 2), &wsaData); 300 if (res != 0) { 301 return nullptr; 302 } 303 
initialized = true; 304 } 305 #else 306 UNUSED(initialized); 307 #endif 308 auto sock = socket_connect(host.c_str(), port); 309 if (sock == nullptr) { 310 return nullptr; 311 } 312 GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); 313 sockets[endpoint] = sock; 314 return sock; 315 } 316 317 GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) { 318 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; 319 return ctx->name.c_str(); 320 } 321 322 GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { 323 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; 324 // input serialization format: | remote_ptr (8 bytes) | 325 std::vector<uint8_t> input(sizeof(uint64_t), 0); 326 uint64_t remote_ptr = ctx->remote_ptr; 327 memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); 328 std::vector<uint8_t> output; 329 bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output); 330 GGML_ASSERT(status); 331 GGML_ASSERT(output.empty()); 332 delete ctx; 333 } 334 335 GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { 336 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; 337 if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) { 338 return ctx->base_cache[buffer]; 339 } 340 // input serialization format: | remote_ptr (8 bytes) | 341 std::vector<uint8_t> input(sizeof(uint64_t), 0); 342 uint64_t remote_ptr = ctx->remote_ptr; 343 memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); 344 std::vector<uint8_t> output; 345 bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output); 346 GGML_ASSERT(status); 347 GGML_ASSERT(output.size() == sizeof(uint64_t)); 348 // output serialization format: | base_ptr (8 bytes) | 349 uint64_t base_ptr; 350 memcpy(&base_ptr, output.data(), sizeof(base_ptr)); 351 
void * base = reinterpret_cast<void *>(base_ptr); 352 ctx->base_cache[buffer] = base; 353 return base; 354 } 355 356 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { 357 rpc_tensor result; 358 result.id = reinterpret_cast<uint64_t>(tensor); 359 result.type = tensor->type; 360 if (tensor->buffer) { 361 ggml_backend_buffer_t buffer = tensor->buffer; 362 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; 363 result.buffer = ctx->remote_ptr; 364 } else { 365 result.buffer = 0; 366 } 367 for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { 368 result.ne[i] = tensor->ne[i]; 369 result.nb[i] = tensor->nb[i]; 370 } 371 result.op = tensor->op; 372 for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { 373 result.op_params[i] = tensor->op_params[i]; 374 } 375 result.flags = tensor->flags; 376 for (uint32_t i = 0; i < GGML_MAX_SRC; i++) { 377 result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]); 378 } 379 result.view_src = reinterpret_cast<uint64_t>(tensor->view_src); 380 result.view_offs = tensor->view_offs; 381 result.data = reinterpret_cast<uint64_t>(tensor->data); 382 snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name); 383 return result; 384 } 385 386 GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { 387 UNUSED(buffer); 388 if (ggml_is_quantized(tensor->type)) { 389 // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized 390 GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor"); 391 } 392 } 393 394 GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { 395 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; 396 // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) | 397 size_t input_size = sizeof(rpc_tensor) + 
sizeof(uint64_t) + size; 398 std::vector<uint8_t> input(input_size, 0); 399 rpc_tensor rpc_tensor = serialize_tensor(tensor); 400 memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); 401 memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); 402 memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); 403 std::vector<uint8_t> output; 404 bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output); 405 GGML_ASSERT(status); 406 } 407 408 GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { 409 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; 410 // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) | 411 int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t); 412 std::vector<uint8_t> input(input_size, 0); 413 rpc_tensor rpc_tensor = serialize_tensor(tensor); 414 memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); 415 memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); 416 memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size)); 417 std::vector<uint8_t> output; 418 bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output); 419 GGML_ASSERT(status); 420 GGML_ASSERT(output.size() == size); 421 // output serialization format: | data (size bytes) | 422 memcpy(data, output.data(), size); 423 } 424 425 GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { 426 // check if src and dst are on the same server 427 ggml_backend_buffer_t src_buffer = src->buffer; 428 ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; 429 ggml_backend_buffer_t dst_buffer = dst->buffer; 430 ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; 431 if (src_ctx->sock != 
dst_ctx->sock) { 432 return false; 433 } 434 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; 435 // input serialization format: | rpc_tensor src | rpc_tensor dst | 436 int input_size = 2*sizeof(rpc_tensor); 437 std::vector<uint8_t> input(input_size, 0); 438 rpc_tensor rpc_src = serialize_tensor(src); 439 rpc_tensor rpc_dst = serialize_tensor(dst); 440 memcpy(input.data(), &rpc_src, sizeof(rpc_src)); 441 memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst)); 442 std::vector<uint8_t> output; 443 bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output); 444 GGML_ASSERT(status); 445 // output serialization format: | result (1 byte) | 446 GGML_ASSERT(output.size() == 1); 447 return output[0]; 448 } 449 450 GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { 451 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; 452 // serialization format: | bufptr (8 bytes) | value (1 byte) | 453 int input_size = sizeof(uint64_t) + sizeof(uint8_t); 454 std::vector<uint8_t> input(input_size, 0); 455 memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr)); 456 memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value)); 457 std::vector<uint8_t> output; 458 bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output); 459 GGML_ASSERT(status); 460 } 461 462 static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { 463 /* .get_name = */ ggml_backend_rpc_buffer_get_name, 464 /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer, 465 /* .get_base = */ ggml_backend_rpc_buffer_get_base, 466 /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor, 467 /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, 468 /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, 469 /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, 470 /* .clear = */ ggml_backend_rpc_buffer_clear, 471 /* .reset = */ NULL, 472 }; 473 
474 GGML_CALL static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) { 475 ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; 476 return buft_ctx->name.c_str(); 477 } 478 479 GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { 480 ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; 481 // input serialization format: | size (8 bytes) | 482 int input_size = sizeof(uint64_t); 483 std::vector<uint8_t> input(input_size, 0); 484 memcpy(input.data(), &size, sizeof(size)); 485 std::vector<uint8_t> output; 486 auto sock = get_socket(buft_ctx->endpoint); 487 bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output); 488 GGML_ASSERT(status); 489 GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); 490 // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | 491 uint64_t remote_ptr; 492 memcpy(&remote_ptr, output.data(), sizeof(remote_ptr)); 493 size_t remote_size; 494 memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size)); 495 if (remote_ptr != 0) { 496 ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, 497 ggml_backend_rpc_buffer_interface, 498 new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"}, 499 remote_size); 500 return buffer; 501 } else { 502 return nullptr; 503 } 504 } 505 506 static size_t get_alignment(const std::shared_ptr<socket_t> & sock) { 507 // input serialization format: | 0 bytes | 508 std::vector<uint8_t> input; 509 std::vector<uint8_t> output; 510 bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output); 511 GGML_ASSERT(status); 512 GGML_ASSERT(output.size() == sizeof(uint64_t)); 513 // output serialization format: | alignment (8 bytes) | 514 uint64_t alignment; 515 memcpy(&alignment, output.data(), sizeof(alignment)); 516 
return alignment; 517 } 518 519 GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { 520 ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; 521 return buft_ctx->alignment; 522 } 523 524 static size_t get_max_size(const std::shared_ptr<socket_t> & sock) { 525 // input serialization format: | 0 bytes | 526 std::vector<uint8_t> input; 527 std::vector<uint8_t> output; 528 bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output); 529 GGML_ASSERT(status); 530 GGML_ASSERT(output.size() == sizeof(uint64_t)); 531 // output serialization format: | max_size (8 bytes) | 532 uint64_t max_size; 533 memcpy(&max_size, output.data(), sizeof(max_size)); 534 return max_size; 535 } 536 537 GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { 538 ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; 539 return buft_ctx->max_size; 540 } 541 542 GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { 543 UNUSED(buft); 544 return ggml_nbytes(tensor); 545 } 546 547 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { 548 /* .get_name = */ ggml_backend_rpc_buffer_type_name, 549 /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer, 550 /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment, 551 /* .get_max_size = */ ggml_backend_rpc_get_max_size, 552 /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size, 553 /* .is_host = */ NULL, 554 }; 555 556 GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { 557 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; 558 559 return rpc_ctx->name.c_str(); 560 } 561 562 GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) { 563 ggml_backend_rpc_context * rpc_ctx = 
(ggml_backend_rpc_context *)backend->context; 564 delete rpc_ctx; 565 delete backend; 566 } 567 568 GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) { 569 ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; 570 return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); 571 } 572 573 GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { 574 UNUSED(backend); 575 // this is no-op because we don't have any async operations 576 } 577 578 static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) { 579 if (tensor == nullptr) { 580 return; 581 } 582 if (visited.find(tensor) != visited.end()) { 583 return; 584 } 585 visited.insert(tensor); 586 for (int i = 0; i < GGML_MAX_SRC; i++) { 587 add_tensor(tensor->src[i], tensors, visited); 588 } 589 add_tensor(tensor->view_src, tensors, visited); 590 tensors.push_back(serialize_tensor(tensor)); 591 } 592 593 static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) { 594 uint32_t n_nodes = cgraph->n_nodes; 595 std::vector<rpc_tensor> tensors; 596 std::unordered_set<ggml_tensor*> visited; 597 for (uint32_t i = 0; i < n_nodes; i++) { 598 add_tensor(cgraph->nodes[i], tensors, visited); 599 } 600 // serialization format: 601 // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | 602 uint32_t n_tensors = tensors.size(); 603 int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); 604 output.resize(output_size, 0); 605 memcpy(output.data(), &n_nodes, sizeof(n_nodes)); 606 for (uint32_t i = 0; i < n_nodes; i++) { 607 memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); 608 } 609 uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * 
sizeof(uint64_t)); 610 *out_ntensors = n_tensors; 611 rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t)); 612 memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); 613 } 614 615 GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { 616 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; 617 std::vector<uint8_t> input; 618 serialize_graph(cgraph, input); 619 std::vector<uint8_t> output; 620 auto sock = get_socket(rpc_ctx->endpoint); 621 bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output); 622 GGML_ASSERT(status); 623 GGML_ASSERT(output.size() == 1); 624 return (enum ggml_status)output[0]; 625 } 626 627 GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) { 628 UNUSED(backend); 629 UNUSED(op); 630 //TODO: call the remote backend and cache the results 631 return true; 632 } 633 634 GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { 635 if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) { 636 return false; 637 } 638 ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; 639 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; 640 return buft_ctx->endpoint == rpc_ctx->endpoint; 641 } 642 643 static ggml_backend_i ggml_backend_rpc_interface = { 644 /* .get_name = */ ggml_backend_rpc_name, 645 /* .free = */ ggml_backend_rpc_free, 646 /* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type, 647 /* .set_tensor_async = */ NULL, 648 /* .get_tensor_async = */ NULL, 649 /* .cpy_tensor_async = */ NULL, 650 /* .synchronize = */ ggml_backend_rpc_synchronize, 651 /* .graph_plan_create = */ NULL, 652 /* .graph_plan_free = */ NULL, 653 /* .graph_plan_update = */ NULL, 654 /* 
.graph_plan_compute = */ NULL, 655 /* .graph_compute = */ ggml_backend_rpc_graph_compute, 656 /* .supports_op = */ ggml_backend_rpc_supports_op, 657 /* .supports_buft = */ ggml_backend_rpc_supports_buft, 658 /* .offload_op = */ NULL, 659 /* .event_new = */ NULL, 660 /* .event_free = */ NULL, 661 /* .event_record = */ NULL, 662 /* .event_wait = */ NULL, 663 /* .event_synchronize = */ NULL, 664 }; 665 666 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { 667 static std::mutex mutex; 668 std::lock_guard<std::mutex> lock(mutex); 669 // NOTE: buffer types are allocated and never freed; this is by design 670 static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map; 671 auto it = buft_map.find(endpoint); 672 if (it != buft_map.end()) { 673 return it->second; 674 } 675 auto sock = get_socket(endpoint); 676 if (sock == nullptr) { 677 return nullptr; 678 } 679 size_t alignment = get_alignment(sock); 680 size_t max_size = get_max_size(sock); 681 ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { 682 /* .endpoint = */ endpoint, 683 /* .name = */ "RPC[" + std::string(endpoint) + "]", 684 /* .alignment = */ alignment, 685 /* .max_size = */ max_size 686 }; 687 688 ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { 689 /* .iface = */ ggml_backend_rpc_buffer_type_interface, 690 /* .context = */ buft_ctx 691 }; 692 buft_map[endpoint] = buft; 693 return buft; 694 } 695 696 GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { 697 ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { 698 /* .endpoint = */ endpoint, 699 /* .name = */ "RPC[" + std::string(endpoint) + "]", 700 }; 701 702 ggml_backend_t backend = new ggml_backend { 703 /* .guid = */ ggml_backend_rpc_guid(), 704 /* .interface = */ ggml_backend_rpc_interface, 705 /* .context = */ ctx 706 }; 707 return backend; 708 } 709 710 GGML_API GGML_CALL bool 
ggml_backend_is_rpc(ggml_backend_t backend) { 711 return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); 712 } 713 714 static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) { 715 // input serialization format: | 0 bytes | 716 std::vector<uint8_t> input; 717 std::vector<uint8_t> output; 718 bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output); 719 GGML_ASSERT(status); 720 GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); 721 // output serialization format: | free (8 bytes) | total (8 bytes) | 722 uint64_t free_mem; 723 memcpy(&free_mem, output.data(), sizeof(free_mem)); 724 uint64_t total_mem; 725 memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem)); 726 *free = free_mem; 727 *total = total_mem; 728 } 729 730 GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { 731 auto sock = get_socket(endpoint); 732 if (sock == nullptr) { 733 *free = 0; 734 *total = 0; 735 return; 736 } 737 get_device_memory(sock, free, total); 738 } 739 740 // RPC server-side implementation 741 742 class rpc_server { 743 public: 744 rpc_server(ggml_backend_t backend) : backend(backend) {} 745 ~rpc_server(); 746 747 bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output); 748 void get_alignment(std::vector<uint8_t> & output); 749 void get_max_size(std::vector<uint8_t> & output); 750 bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output); 751 bool free_buffer(const std::vector<uint8_t> & input); 752 bool buffer_clear(const std::vector<uint8_t> & input); 753 bool set_tensor(const std::vector<uint8_t> & input); 754 bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output); 755 bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output); 756 bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & 
output); 757 758 private: 759 ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); 760 ggml_tensor * create_node(uint64_t id, 761 struct ggml_context * ctx, 762 const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs, 763 std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map); 764 765 766 ggml_backend_t backend; 767 std::unordered_set<ggml_backend_buffer_t> buffers; 768 }; 769 770 bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { 771 // input serialization format: | size (8 bytes) | 772 if (input.size() != sizeof(uint64_t)) { 773 return false; 774 } 775 uint64_t size; 776 memcpy(&size, input.data(), sizeof(size)); 777 ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); 778 ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size); 779 uint64_t remote_ptr = 0; 780 uint64_t remote_size = 0; 781 if (buffer != nullptr) { 782 remote_ptr = reinterpret_cast<uint64_t>(buffer); 783 remote_size = buffer->size; 784 GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size); 785 buffers.insert(buffer); 786 } else { 787 GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size); 788 } 789 // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | 790 output.resize(2*sizeof(uint64_t), 0); 791 memcpy(output.data(), &remote_ptr, sizeof(remote_ptr)); 792 memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size)); 793 return true; 794 } 795 796 void rpc_server::get_alignment(std::vector<uint8_t> & output) { 797 ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); 798 size_t alignment = ggml_backend_buft_get_alignment(buft); 799 GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment); 800 // output serialization format: | alignment (8 bytes) | 801 
output.resize(sizeof(uint64_t), 0); 802 memcpy(output.data(), &alignment, sizeof(alignment)); 803 } 804 805 void rpc_server::get_max_size(std::vector<uint8_t> & output) { 806 ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); 807 size_t max_size = ggml_backend_buft_get_max_size(buft); 808 GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size); 809 // output serialization format: | max_size (8 bytes) | 810 output.resize(sizeof(uint64_t), 0); 811 memcpy(output.data(), &max_size, sizeof(max_size)); 812 } 813 814 bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { 815 // input serialization format: | remote_ptr (8 bytes) | 816 if (input.size() != sizeof(uint64_t)) { 817 return false; 818 } 819 uint64_t remote_ptr; 820 memcpy(&remote_ptr, input.data(), sizeof(remote_ptr)); 821 GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr); 822 ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr); 823 if (buffers.find(buffer) == buffers.end()) { 824 GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); 825 return false; 826 } 827 void * base = ggml_backend_buffer_get_base(buffer); 828 // output serialization format: | base_ptr (8 bytes) | 829 uint64_t base_ptr = reinterpret_cast<uint64_t>(base); 830 output.resize(sizeof(uint64_t), 0); 831 memcpy(output.data(), &base_ptr, sizeof(base_ptr)); 832 return true; 833 } 834 835 bool rpc_server::free_buffer(const std::vector<uint8_t> & input) { 836 // input serialization format: | remote_ptr (8 bytes) | 837 if (input.size() != sizeof(uint64_t)) { 838 return false; 839 } 840 uint64_t remote_ptr; 841 memcpy(&remote_ptr, input.data(), sizeof(remote_ptr)); 842 GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr); 843 ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr); 844 if (buffers.find(buffer) == buffers.end()) { 845 GGML_PRINT_DEBUG("[%s] buffer not 
found\n", __func__); 846 return false; 847 } 848 ggml_backend_buffer_free(buffer); 849 buffers.erase(buffer); 850 return true; 851 } 852 853 bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) { 854 // input serialization format: | remote_ptr (8 bytes) | value (1 byte) | 855 if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) { 856 return false; 857 } 858 uint64_t remote_ptr; 859 memcpy(&remote_ptr, input.data(), sizeof(remote_ptr)); 860 uint8_t value; 861 memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value)); 862 GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value); 863 ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr); 864 if (buffers.find(buffer) == buffers.end()) { 865 GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); 866 return false; 867 } 868 ggml_backend_buffer_clear(buffer, value); 869 return true; 870 } 871 872 ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) { 873 ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, 874 tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); 875 for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { 876 result->nb[i] = tensor->nb[i]; 877 } 878 result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer); 879 if (result->buffer && buffers.find(result->buffer) == buffers.end()) { 880 return nullptr; 881 } 882 result->op = (ggml_op) tensor->op; 883 for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { 884 result->op_params[i] = tensor->op_params[i]; 885 } 886 result->flags = tensor->flags; 887 result->data = reinterpret_cast<void *>(tensor->data); 888 ggml_set_name(result, tensor->name); 889 return result; 890 } 891 892 893 bool rpc_server::set_tensor(const std::vector<uint8_t> & input) { 894 // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) | 895 if (input.size() < sizeof(rpc_tensor) + 
sizeof(uint64_t)) { 896 return false; 897 } 898 const rpc_tensor * in_tensor = (const rpc_tensor *)input.data(); 899 uint64_t offset; 900 memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset)); 901 size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset); 902 903 struct ggml_init_params params { 904 /*.mem_size =*/ ggml_tensor_overhead(), 905 /*.mem_buffer =*/ NULL, 906 /*.no_alloc =*/ true, 907 }; 908 struct ggml_context * ctx = ggml_init(params); 909 ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); 910 if (tensor == nullptr) { 911 GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__); 912 ggml_free(ctx); 913 return false; 914 } 915 GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); 916 const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset); 917 ggml_backend_tensor_set(tensor, data, offset, size); 918 ggml_free(ctx); 919 return true; 920 } 921 922 bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { 923 // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) | 924 if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) { 925 return false; 926 } 927 const rpc_tensor * in_tensor = (const rpc_tensor *)input.data(); 928 uint64_t offset; 929 memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset)); 930 uint64_t size; 931 memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size)); 932 933 struct ggml_init_params params { 934 /*.mem_size =*/ ggml_tensor_overhead(), 935 /*.mem_buffer =*/ NULL, 936 /*.no_alloc =*/ true, 937 }; 938 struct ggml_context * ctx = ggml_init(params); 939 ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); 940 if (tensor == nullptr) { 941 GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__); 942 ggml_free(ctx); 943 return false; 944 } 945 GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, 
offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); 946 // output serialization format: | data (size bytes) | 947 output.resize(size, 0); 948 ggml_backend_tensor_get(tensor, output.data(), offset, size); 949 ggml_free(ctx); 950 return true; 951 } 952 953 bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { 954 // serialization format: | rpc_tensor src | rpc_tensor dst | 955 if (input.size() != 2*sizeof(rpc_tensor)) { 956 return false; 957 } 958 const rpc_tensor * rpc_src = (const rpc_tensor *)input.data(); 959 const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src)); 960 961 struct ggml_init_params params { 962 /*.mem_size =*/ 2*ggml_tensor_overhead(), 963 /*.mem_buffer =*/ NULL, 964 /*.no_alloc =*/ true, 965 }; 966 struct ggml_context * ctx = ggml_init(params); 967 ggml_tensor * src = deserialize_tensor(ctx, rpc_src); 968 ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst); 969 if (src == nullptr || dst == nullptr) { 970 GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__); 971 ggml_free(ctx); 972 return false; 973 } 974 GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer); 975 bool result = ggml_backend_buffer_copy_tensor(src, dst); 976 // output serialization format: | result (1 byte) | 977 output.resize(1, 0); 978 output[0] = result; 979 ggml_free(ctx); 980 return true; 981 } 982 983 ggml_tensor * rpc_server::create_node(uint64_t id, 984 struct ggml_context * ctx, 985 const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs, 986 std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) { 987 if (id == 0) { 988 return nullptr; 989 } 990 if (tensor_map.find(id) != tensor_map.end()) { 991 return tensor_map[id]; 992 } 993 const rpc_tensor * tensor = tensor_ptrs.at(id); 994 struct ggml_tensor * result = deserialize_tensor(ctx, tensor); 995 if (result == 
nullptr) { 996 return nullptr; 997 } 998 tensor_map[id] = result; 999 for (int i = 0; i < GGML_MAX_SRC; i++) { 1000 result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map); 1001 } 1002 result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map); 1003 result->view_offs = tensor->view_offs; 1004 return result; 1005 } 1006 1007 bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { 1008 // serialization format: 1009 // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | 1010 if (input.size() < sizeof(uint32_t)) { 1011 return false; 1012 } 1013 uint32_t n_nodes; 1014 memcpy(&n_nodes, input.data(), sizeof(n_nodes)); 1015 if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) { 1016 return false; 1017 } 1018 const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes)); 1019 uint32_t n_tensors; 1020 memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors)); 1021 if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) { 1022 return false; 1023 } 1024 const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors)); 1025 GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors); 1026 1027 static size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); 1028 struct ggml_init_params params = { 1029 /*.mem_size =*/ buf_size, 1030 /*.mem_buffer =*/ NULL, 1031 /*.no_alloc =*/ true, 1032 }; 1033 struct ggml_context * ctx = ggml_init(params); 1034 struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); 1035 graph->n_nodes = n_nodes; 1036 std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs; 1037 for (uint32_t i = 0; i < n_tensors; 
i++) { 1038 tensor_ptrs[tensors[i].id] = &tensors[i]; 1039 } 1040 std::unordered_map<uint64_t, ggml_tensor*> tensor_map; 1041 for (uint32_t i = 0; i < n_nodes; i++) { 1042 int64_t id; 1043 memcpy(&id, &nodes[i], sizeof(id)); 1044 graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map); 1045 } 1046 ggml_status status = ggml_backend_graph_compute(backend, graph); 1047 // output serialization format: | status (1 byte) | 1048 output.resize(1, 0); 1049 output[0] = status; 1050 ggml_free(ctx); 1051 return true; 1052 } 1053 1054 rpc_server::~rpc_server() { 1055 for (auto buffer : buffers) { 1056 ggml_backend_buffer_free(buffer); 1057 } 1058 } 1059 1060 static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) { 1061 rpc_server server(backend); 1062 while (true) { 1063 uint8_t cmd; 1064 if (!recv_data(sockfd, &cmd, 1)) { 1065 break; 1066 } 1067 std::vector<uint8_t> input; 1068 std::vector<uint8_t> output; 1069 uint64_t input_size; 1070 if (!recv_data(sockfd, &input_size, sizeof(input_size))) { 1071 break; 1072 } 1073 input.resize(input_size); 1074 if (!recv_data(sockfd, input.data(), input_size)) { 1075 break; 1076 } 1077 bool ok = true; 1078 switch (cmd) { 1079 case ALLOC_BUFFER: { 1080 ok = server.alloc_buffer(input, output); 1081 break; 1082 } 1083 case GET_ALIGNMENT: { 1084 server.get_alignment(output); 1085 break; 1086 } 1087 case GET_MAX_SIZE: { 1088 server.get_max_size(output); 1089 break; 1090 } 1091 case BUFFER_GET_BASE: { 1092 ok = server.buffer_get_base(input, output); 1093 break; 1094 } 1095 case FREE_BUFFER: { 1096 ok = server.free_buffer(input); 1097 break; 1098 } 1099 case BUFFER_CLEAR: { 1100 ok = server.buffer_clear(input); 1101 break; 1102 } 1103 case SET_TENSOR: { 1104 ok = server.set_tensor(input); 1105 break; 1106 } 1107 case GET_TENSOR: { 1108 ok = server.get_tensor(input, output); 1109 break; 1110 } 1111 case COPY_TENSOR: { 1112 ok = server.copy_tensor(input, output); 1113 break; 1114 } 1115 case 
GRAPH_COMPUTE: { 1116 ok = server.graph_compute(input, output); 1117 break; 1118 } 1119 case GET_DEVICE_MEMORY: { 1120 // output serialization format: | free (8 bytes) | total (8 bytes) | 1121 output.resize(2*sizeof(uint64_t), 0); 1122 memcpy(output.data(), &free_mem, sizeof(free_mem)); 1123 memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem)); 1124 break; 1125 } 1126 default: { 1127 fprintf(stderr, "Unknown command: %d\n", cmd); 1128 ok = false; 1129 } 1130 } 1131 if (!ok) { 1132 break; 1133 } 1134 uint64_t output_size = output.size(); 1135 if (!send_data(sockfd, &output_size, sizeof(output_size))) { 1136 break; 1137 } 1138 if (!send_data(sockfd, output.data(), output_size)) { 1139 break; 1140 } 1141 } 1142 } 1143 1144 void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) { 1145 std::string host; 1146 int port; 1147 if (!parse_endpoint(endpoint, host, port)) { 1148 return; 1149 } 1150 #ifdef _WIN32 1151 { 1152 WSADATA wsaData; 1153 int res = WSAStartup(MAKEWORD(2, 2), &wsaData); 1154 if (res != 0) { 1155 fprintf(stderr, "WSAStartup failed: %d\n", res); 1156 return; 1157 } 1158 } 1159 #endif 1160 auto server_socket = create_server_socket(host.c_str(), port); 1161 if (server_socket == nullptr) { 1162 fprintf(stderr, "Failed to create server socket\n"); 1163 return; 1164 } 1165 while (true) { 1166 auto client_socket = socket_accept(server_socket->fd); 1167 if (client_socket == nullptr) { 1168 fprintf(stderr, "Failed to accept client connection\n"); 1169 return; 1170 } 1171 printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); 1172 rpc_serve_client(backend, client_socket->fd, free_mem, total_mem); 1173 printf("Client connection closed\n"); 1174 } 1175 #ifdef _WIN32 1176 WSACleanup(); 1177 #endif 1178 }