// kernel/io_uring/io_uring.c — io_uring submission/completion ring implementation
  1  #include "io_uring.h"
  2  #include "memory.h"
  3  #include "kernel.h"
  4  #include "string.h"
  5  #include "process.h"
  6  #include "fd.h"
  7  #include "fs.h"
  8  
// Maximum number of io_uring instances that can exist at once.
#define MAX_IO_RINGS 16
// Per-ring limits for resources attached via io_uring_register().
#define MAX_REGISTERED_BUFFERS 64
#define MAX_REGISTERED_FILES 32

// io_uring instance tracking. The slot index doubles as the ring "fd"
// returned by io_uring_setup(); a NULL slot is free.
static struct io_uring* io_rings[MAX_IO_RINGS];
// Bump-allocator cursor into the fixed shared-memory window starting at
// IORING_SHARED_MEMORY_BASE. Only ever advances in this file — ring
// memory is not reclaimed when a ring is torn down.
static uint32_t shared_memory_offset = 0;
 17  
// Extended io_uring structure carrying per-ring registered resources.
// NOTE: `ring` must remain the FIRST member — io_uring_register() recovers
// this container from an io_rings[] entry via
// offsetof(struct io_uring_registered, ring).
struct io_uring_registered {
    struct io_uring ring;   // embedded base ring (must stay first)
    
    // Buffers attached via IORING_REGISTER_BUFFERS.
    struct {
        void* addr;         // buffer base address (as supplied by caller)
        size_t len;         // buffer length in bytes
    } registered_buffers[MAX_REGISTERED_BUFFERS];
    uint32_t nr_registered_buffers;   // number of valid entries above; 0 = none registered
    
    // File descriptors attached via IORING_REGISTER_FILES.
    int registered_files[MAX_REGISTERED_FILES];
    uint32_t nr_registered_files;     // number of valid entries above; 0 = none registered
};
 33  
// Initialize an io_uring instance.
//
// Carves the SQ ring, SQE array, and CQ ring for `ring` out of the global
// shared-memory window using the bump allocator `shared_memory_offset`.
// Layout relative to base_addr:
//   [0 .. PAGE_SIZE + sq_ring_size)        SQ control words + sq_array
//   [next .. + sqes_size)                  SQE array
//   [next .. + PAGE_SIZE + cq_ring_size)   CQ control words; CQEs at +PAGE_SIZE
//
// `entries` is assumed to be a non-zero power of two (validated by
// io_uring_setup() before calling here) so that entries - 1 is a valid
// index mask. Returns 0 on success, -1 if the shared window is exhausted.
static int init_io_uring(struct io_uring* ring, uint32_t entries) {
    // Calculate ring sizes: the CQ holds twice as many entries as the SQ
    // so batched submissions are less likely to overflow completions.
    uint32_t sq_entries = entries;
    uint32_t cq_entries = entries * 2; // CQ is typically larger
    
    // Calculate memory sizes. sq_ring_size covers only the sq_array index
    // slots; the six SQ control words live at the start of the region.
    size_t sq_ring_size = sq_entries * sizeof(uint32_t);
    size_t cq_ring_size = cq_entries * sizeof(struct io_uring_cqe);
    size_t sqes_size = sq_entries * sizeof(struct io_uring_sqe);
    
    // Check if we have enough shared memory left in the fixed window.
    size_t total_size = PAGE_SIZE + sq_ring_size + PAGE_SIZE + cq_ring_size + sqes_size;
    if (shared_memory_offset + total_size > IORING_MAX_SHARED_SIZE) {
        return -1; // Out of shared memory
    }
    
    // Place this ring at the current bump-allocator cursor.
    uint32_t base_addr = IORING_SHARED_MEMORY_BASE + shared_memory_offset;
    void* sq_ring = (void*)base_addr;
    ring->sq_ring_addr = base_addr;
    
    // Set up SQ pointers: consecutive uint32_t control words, matching the
    // sq_off offsets published in params below; sq_array follows them.
    ring->sq_head = (uint32_t*)sq_ring;
    ring->sq_tail = ring->sq_head + 1;
    ring->sq_mask = ring->sq_tail + 1;
    ring->sq_entries = ring->sq_mask + 1;
    ring->sq_flags = ring->sq_entries + 1;
    ring->sq_dropped = ring->sq_flags + 1;
    ring->sq_array = ring->sq_dropped + 1;
    
    // Initialize SQ values (mask is valid because entries is a power of two).
    *ring->sq_head = 0;
    *ring->sq_tail = 0;
    *ring->sq_mask = sq_entries - 1;
    *ring->sq_entries = sq_entries;
    *ring->sq_flags = 0;
    *ring->sq_dropped = 0;
    
    // SQE array starts one page plus the sq_array past the region base.
    ring->sqes_addr = base_addr + PAGE_SIZE + sq_ring_size;
    ring->sqes = (struct io_uring_sqe*)ring->sqes_addr;
    memset(ring->sqes, 0, sqes_size);
    
    // The CQ region immediately follows the SQE array.
    ring->cq_ring_addr = ring->sqes_addr + sqes_size;
    void* cq_ring = (void*)ring->cq_ring_addr;
    
    // Set up CQ pointers: control words first, then the CQE array one page
    // into the region (mirrors cq_off.cqes = PAGE_SIZE below).
    ring->cq_head = (uint32_t*)cq_ring;
    ring->cq_tail = ring->cq_head + 1;
    ring->cq_mask = ring->cq_tail + 1;
    ring->cq_entries = ring->cq_mask + 1;
    ring->cq_overflow = ring->cq_entries + 1;
    ring->cqes = (struct io_uring_cqe*)((uint8_t*)cq_ring + PAGE_SIZE);
    
    // Initialize CQ values
    *ring->cq_head = 0;
    *ring->cq_tail = 0;
    *ring->cq_mask = cq_entries - 1;
    *ring->cq_entries = cq_entries;
    *ring->cq_overflow = 0;
    
    // Publish ring geometry to userspace via params.
    ring->params.sq_entries = sq_entries;
    ring->params.cq_entries = cq_entries;
    ring->params.flags = 0;
    
    // Byte offsets of each SQ control word from sq_ring_addr; must match
    // the pointer layout established above.
    ring->params.sq_off.head = 0;
    ring->params.sq_off.tail = sizeof(uint32_t);
    ring->params.sq_off.ring_mask = sizeof(uint32_t) * 2;
    ring->params.sq_off.ring_entries = sizeof(uint32_t) * 3;
    ring->params.sq_off.flags = sizeof(uint32_t) * 4;
    ring->params.sq_off.dropped = sizeof(uint32_t) * 5;
    ring->params.sq_off.array = sizeof(uint32_t) * 6;
    
    // Byte offsets of each CQ control word from cq_ring_addr.
    ring->params.cq_off.head = 0;
    ring->params.cq_off.tail = sizeof(uint32_t);
    ring->params.cq_off.ring_mask = sizeof(uint32_t) * 2;
    ring->params.cq_off.ring_entries = sizeof(uint32_t) * 3;
    ring->params.cq_off.overflow = sizeof(uint32_t) * 4;
    ring->params.cq_off.cqes = PAGE_SIZE;
    
    // Store region sizes (each includes its leading control-word page).
    ring->sq_ring_size = sq_ring_size + PAGE_SIZE;
    ring->cq_ring_size = cq_ring_size + PAGE_SIZE;
    ring->sqes_size = sqes_size;
    
    // Advance the bump allocator past everything we just claimed.
    shared_memory_offset += total_size;
    
    // Make addresses available in params for userspace.
    // NOTE(review): the absolute addresses are smuggled through the
    // reserved params fields instead of being mmap'd by offset —
    // presumably the userspace shim reads resv[0..2]; confirm.
    ring->params.resv[0] = ring->sq_ring_addr;
    ring->params.resv[1] = ring->cq_ring_addr;
    ring->params.resv[2] = ring->sqes_addr;
    
    return 0;
}
133  
134  // io_uring_setup system call
135  int io_uring_setup(uint32_t entries, struct io_uring_params *params) {
136      // Validate entries (must be power of 2)
137      if (entries == 0 || (entries & (entries - 1)) != 0) {
138          return -1; // -EINVAL
139      }
140      
141      // Find free fd slot
142      int fd = -1;
143      for (int i = 0; i < MAX_IO_RINGS; i++) {
144          if (io_rings[i] == NULL) {
145              fd = i;
146              break;
147          }
148      }
149      
150      if (fd == -1) {
151          return -1; // -EMFILE
152      }
153      
154      // Allocate io_uring structure
155      struct io_uring_registered* ring = (struct io_uring_registered*)kmalloc(sizeof(struct io_uring_registered));
156      if (!ring) {
157          return -1; // -ENOMEM
158      }
159      
160      // Clear the structure
161      memset(ring, 0, sizeof(struct io_uring_registered));
162      
163      // Initialize the ring
164      if (init_io_uring(&ring->ring, entries) < 0) {
165          kfree(ring);
166          return -1;
167      }
168      
169      // Copy parameters back to user
170      if (params) {
171          memcpy(params, &ring->ring.params, sizeof(struct io_uring_params));
172      }
173      
174      // Store ring
175      io_rings[fd] = &ring->ring;
176      
177      return fd;
178  }
179  
180  // Check if any blocked processes can be unblocked
181  static void check_blocked_processes(int ring_fd) {
182      struct io_uring* ring = io_rings[ring_fd];
183      if (!ring) return;
184      
185      // Calculate current completions
186      uint32_t cq_head = *ring->cq_head;
187      uint32_t cq_tail = *ring->cq_tail;
188      uint32_t completed = 0;
189      
190      if (cq_tail >= cq_head) {
191          completed = cq_tail - cq_head;
192      } else {
193          completed = (*ring->cq_entries - cq_head) + cq_tail;
194      }
195      
196      // Check all processes
197      struct process* proc = process_list;
198      while (proc) {
199          if (proc->state == PROCESS_BLOCKED && 
200              proc->blocked_on_ring_fd == ring_fd &&
201              completed >= proc->min_complete) {
202              process_unblock(proc);
203          }
204          proc = proc->next;
205      }
206  }
207  
// Process a single submission queue entry synchronously and post its
// completion to the CQ of `ring`. On CQ overflow the SQE's result is
// dropped (cq_overflow is incremented, no CQE is written). After posting,
// any process blocked on `ring_fd` is re-checked for wakeup.
static void process_sqe(struct io_uring* ring, struct io_uring_sqe* sqe, int ring_fd) {
    struct io_uring_cqe* cqe;
    // CQ tail is kept masked (always < cq_entries), so advancing wraps here.
    uint32_t tail = *ring->cq_tail;
    uint32_t next_tail = (tail + 1) & *ring->cq_mask;
    
    // Debug output disabled to prevent console spam
    // terminal_writestring("io_uring: Processing SQE opcode=");
    // if (sqe->opcode == IORING_OP_NOP) terminal_writestring("NOP");
    // else if (sqe->opcode == IORING_OP_WRITE) terminal_writestring("WRITE");
    // else if (sqe->opcode == IORING_OP_READ) terminal_writestring("READ");
    // else terminal_writestring("UNKNOWN");
    // terminal_writestring("\n");
    
    // Check if CQ is full. One slot is deliberately left unused so a full
    // ring (next_tail == head) is distinguishable from an empty one.
    if (next_tail == *ring->cq_head) {
        (*ring->cq_overflow)++;
        terminal_writestring("io_uring: CQ overflow!\n");
        return;
    }
    
    // Get CQE slot at the current (not yet published) tail.
    cqe = &ring->cqes[tail];
    
    // Process based on opcode; each case fills cqe->res.
    switch (sqe->opcode) {
        case IORING_OP_NOP:
            // No operation - just complete
            cqe->res = 0;
            break;
            
        case IORING_OP_READ:
            // Read via the process file-descriptor table; sqe->addr is the
            // destination buffer, sqe->len the byte count.
            {
                struct file_descriptor* file = fd_get(sqe->fd);
                if (file && file->read) {
                    void* buf = (void*)(uintptr_t)sqe->addr;
                    cqe->res = file->read(file, buf, sqe->len);
                } else {
                    cqe->res = -1; // -EBADF
                }
            }
            break;
            
        case IORING_OP_WRITE:
            // Write via the file-descriptor table; sqe->addr is the source
            // buffer, sqe->len the byte count.
            {
                struct file_descriptor* file = fd_get(sqe->fd);
                if (file && file->write) {
                    const void* buf = (const void*)(uintptr_t)sqe->addr;
                    cqe->res = file->write(file, buf, sqe->len);
                } else {
                    cqe->res = -1; // -EBADF
                }
            }
            break;
            
        case IORING_OP_OPENAT:
            // Open file (simplified): sqe->addr points at a NUL-terminated
            // path looked up directly under fs_root; only regular files open.
            {
                const char* path = (const char*)(uintptr_t)sqe->addr;
                struct fs_node* node = fs_root->finddir(fs_root, (char*)path);
                
                if (node && (node->flags & FS_FILE)) {
                    // Create file descriptor and hand its number back in res.
                    struct file_descriptor* new_fd = fd_create_for_file(node);
                    if (new_fd) {
                        int fd_num = fd_allocate(new_fd);
                        cqe->res = fd_num;
                    } else {
                        cqe->res = -1; // -ENOMEM
                    }
                } else {
                    cqe->res = -1; // -ENOENT
                }
            }
            break;
            
        case IORING_OP_CLOSE:
            // Close file descriptor. fd_close's result is ignored; the CQE
            // always reports success.
            {
                fd_close(sqe->fd);
                cqe->res = 0;
            }
            break;
            
        case IORING_OP_READV:
            // Vectored read: sqe->addr points at an array of sqe->len iovecs.
            // Result is the total bytes read; stops early on a short read.
            {
                struct file_descriptor* file = fd_get(sqe->fd);
                if (file && file->read) {
                    struct iovec {
                        void* iov_base;
                        size_t iov_len;
                    } *iovecs = (struct iovec*)(uintptr_t)sqe->addr;
                    uint32_t nr_vecs = sqe->len;
                    int32_t total = 0;
                    
                    for (uint32_t i = 0; i < nr_vecs; i++) {
                        int ret = file->read(file, iovecs[i].iov_base, iovecs[i].iov_len);
                        if (ret < 0) {
                            // Propagate the error even if earlier vectors
                            // transferred data.
                            cqe->res = ret;
                            goto done;
                        }
                        total += ret;
                        if (ret < (int32_t)iovecs[i].iov_len) {
                            break; // Short read
                        }
                    }
                    cqe->res = total;
                } else {
                    cqe->res = -1; // -EBADF
                }
            }
            break;
            
        case IORING_OP_WRITEV:
            // Vectored write: mirrors READV with file->write.
            {
                struct file_descriptor* file = fd_get(sqe->fd);
                if (file && file->write) {
                    struct iovec {
                        void* iov_base;
                        size_t iov_len;
                    } *iovecs = (struct iovec*)(uintptr_t)sqe->addr;
                    uint32_t nr_vecs = sqe->len;
                    int32_t total = 0;
                    
                    for (uint32_t i = 0; i < nr_vecs; i++) {
                        int ret = file->write(file, iovecs[i].iov_base, iovecs[i].iov_len);
                        if (ret < 0) {
                            cqe->res = ret;
                            goto done;
                        }
                        total += ret;
                        if (ret < (int32_t)iovecs[i].iov_len) {
                            break; // Short write
                        }
                    }
                    cqe->res = total;
                } else {
                    cqe->res = -1; // -EBADF
                }
            }
            break;
            
        case IORING_OP_CONNECT:
            // NOTE(review): CONNECT currently just forks, identically to
            // IORING_OP_FORK below (but without the negative-pid clamp).
            // Looks like a placeholder — confirm intended semantics.
            {
                int child_pid = process_fork();
                cqe->res = child_pid;
            }
            break;
            
        case IORING_OP_FORK:
            // Fork the current process; res is the child pid, or -1 on error.
            {
                int pid = process_fork();
                cqe->res = pid >= 0 ? pid : -1;
            }
            break;
            
        default:
            terminal_writestring("io_uring: Unknown operation\n");
            cqe->res = -1; // -EINVAL
            break;
    }
    
done:
    // Common completion path (also reached via goto on vectored-I/O error):
    // copy user_data through so userspace can match the completion.
    cqe->user_data = sqe->user_data;
    cqe->flags = 0;
    
    // Publish the CQE by advancing the tail.
    *ring->cq_tail = next_tail;
    
    // Check if any processes blocked on this ring can now be unblocked.
    check_blocked_processes(ring_fd);
}
385  
// io_uring_enter system call.
//
// Submits up to `to_submit` pending SQEs from ring `fd`, executing each
// synchronously via process_sqe(), then — if fewer than `min_complete`
// completions are available — blocks the calling process until
// check_blocked_processes() wakes it.
//
// Return value: the number of SQEs consumed, OR, when the process had to
// block, the number of completions available after wakeup.
// NOTE(review): this dual return semantic differs from Linux (which
// returns the number of SQEs consumed) — confirm callers expect it.
// `flags` and `sig` are accepted but currently ignored.
int io_uring_enter(int fd, uint32_t to_submit, uint32_t min_complete, 
                   uint32_t flags, void *sig) {
    // Unused parameters (for future implementation)
    (void)flags;
    (void)sig;
    
    // Validate fd (an index into io_rings[], not a process fd).
    if (fd < 0 || fd >= MAX_IO_RINGS || !io_rings[fd]) {
        return -1; // -EBADF
    }
    
    struct io_uring* ring = io_rings[fd];
    uint32_t submitted = 0;
    uint32_t completed = 0;
    
    // Submit SQEs
    if (to_submit > 0) {
        // Memory barrier to ensure we see updates from userspace
        __asm__ volatile("mfence" ::: "memory");
        
        uint32_t head = *ring->sq_head;
        uint32_t tail = *ring->sq_tail;
        
        // Drain entries between head and tail. sq_array maps ring
        // positions to SQE indices, so head is masked on each access.
        while (submitted < to_submit && head != tail) {
            uint32_t index = ring->sq_array[head & *ring->sq_mask];
            
            // Bounds check: consume (but skip) entries whose SQE index is
            // out of range rather than reading past the SQE array.
            if (index >= *ring->sq_entries) {
                head++;
                continue;
            }
            
            struct io_uring_sqe* sqe = &ring->sqes[index];
            
            // Synchronous execution: the CQE is posted before this returns.
            process_sqe(ring, sqe, fd);
            
            head++;
            submitted++;
        }
        
        // Update head with memory barrier so userspace sees consumption.
        __asm__ volatile("mfence" ::: "memory");
        *ring->sq_head = head;
    }
    
    // Check completions - do this after processing submissions
    __asm__ volatile("mfence" ::: "memory");
    uint32_t cq_head = *ring->cq_head;
    uint32_t cq_tail = *ring->cq_tail;
    
    // CQ indices are kept masked (see process_sqe), so the available count
    // must handle wraparound explicitly; no extra masking needed.
    if (cq_tail >= cq_head) {
        completed = cq_tail - cq_head;
    } else {
        // Handle wraparound
        completed = (*ring->cq_entries - cq_head) + cq_tail;
    }
    
    // Debug output disabled
    // if (min_complete > 0) {
    //     terminal_writestring("io_uring_enter: waiting for completions, have ");
    //     char buf[16];
    //     itoa(completed, buf, 10);
    //     terminal_writestring(buf);
    //     terminal_writestring(" need ");
    //     itoa(min_complete, buf, 10);
    //     terminal_writestring(buf);
    //     terminal_writestring("\n");
    // }
    
    // Wait for min_complete if necessary
    if (min_complete > 0 && completed < min_complete) {
        // Block the current process until enough completions are posted.
        // terminal_writestring("io_uring_enter: blocking process\n");
        process_block(fd, min_complete);
        
        // When we return here, we've been unblocked.
        // Recalculate completions — more may have arrived while blocked.
        __asm__ volatile("mfence" ::: "memory");
        cq_head = *ring->cq_head;
        cq_tail = *ring->cq_tail;
        
        if (cq_tail >= cq_head) {
            completed = cq_tail - cq_head;
        } else {
            completed = (*ring->cq_entries - cq_head) + cq_tail;
        }
        
        return completed;
    }
    
    return submitted;
}
480  
481  // io_uring_register system call
482  int io_uring_register(int fd, uint32_t opcode, void *arg, uint32_t nr_args) {
483      // Validate fd
484      if (fd < 0 || fd >= MAX_IO_RINGS || !io_rings[fd]) {
485          return -1; // -EBADF
486      }
487      
488      // Get the extended ring structure
489      struct io_uring_registered* ring = (struct io_uring_registered*)
490          ((char*)io_rings[fd] - offsetof(struct io_uring_registered, ring));
491      
492      switch (opcode) {
493          case IORING_REGISTER_BUFFERS:
494              {
495                  if (ring->nr_registered_buffers > 0) {
496                      return -1; // -EBUSY
497                  }
498                  
499                  if (nr_args > MAX_REGISTERED_BUFFERS) {
500                      return -1; // -EINVAL
501                  }
502                  
503                  struct iovec {
504                      void* iov_base;
505                      size_t iov_len;
506                  } *iovecs = (struct iovec*)arg;
507                  
508                  // Register buffers
509                  for (uint32_t i = 0; i < nr_args; i++) {
510                      ring->registered_buffers[i].addr = iovecs[i].iov_base;
511                      ring->registered_buffers[i].len = iovecs[i].iov_len;
512                  }
513                  ring->nr_registered_buffers = nr_args;
514                  
515                  terminal_writestring("io_uring: Registered ");
516                  char buf[16];
517                  itoa(nr_args, buf, 10);
518                  terminal_writestring(buf);
519                  terminal_writestring(" buffers\n");
520                  return 0;
521              }
522              
523          case IORING_UNREGISTER_BUFFERS:
524              ring->nr_registered_buffers = 0;
525              terminal_writestring("io_uring: Unregistered buffers\n");
526              return 0;
527              
528          case IORING_REGISTER_FILES:
529              {
530                  if (ring->nr_registered_files > 0) {
531                      return -1; // -EBUSY
532                  }
533                  
534                  if (nr_args > MAX_REGISTERED_FILES) {
535                      return -1; // -EINVAL
536                  }
537                  
538                  int* fds = (int*)arg;
539                  
540                  // Register files
541                  for (uint32_t i = 0; i < nr_args; i++) {
542                      ring->registered_files[i] = fds[i];
543                  }
544                  ring->nr_registered_files = nr_args;
545                  
546                  terminal_writestring("io_uring: Registered ");
547                  char buf[16];
548                  itoa(nr_args, buf, 10);
549                  terminal_writestring(buf);
550                  terminal_writestring(" files\n");
551                  return 0;
552              }
553              
554          case IORING_UNREGISTER_FILES:
555              ring->nr_registered_files = 0;
556              terminal_writestring("io_uring: Unregistered files\n");
557              return 0;
558              
559          default:
560              terminal_writestring("io_uring_register: Unsupported opcode\n");
561              return -1; // -EINVAL
562      }
563  }
564  
565  // Initialize io_uring subsystem
566  void io_uring_init(void) {
567      // Clear io_rings array
568      for (int i = 0; i < MAX_IO_RINGS; i++) {
569          io_rings[i] = NULL;
570      }
571      
572      terminal_writestring("io_uring subsystem initialized\n");
573  }