// ggml-kompute.cpp
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-backend-impl.h"
#include "ggml-kompute.h"

// These are generated at build time by a CMake custom command
#include "shaderop_scale.h"
#include "shaderop_scale_8.h"
#include "shaderop_add.h"
#include "shaderop_addrow.h"
#include "shaderop_mul.h"
#include "shaderop_silu.h"
#include "shaderop_relu.h"
#include "shaderop_gelu.h"
#include "shaderop_softmax.h"
#include "shaderop_norm.h"
#include "shaderop_rmsnorm.h"
#include "shaderop_diagmask.h"
#include "shaderop_mul_mat_f16.h"
#include "shaderop_mul_mat_q8_0.h"
#include "shaderop_mul_mat_q4_0.h"
#include "shaderop_mul_mat_q4_1.h"
#include "shaderop_mul_mat_q6_k.h"
#include "shaderop_mul_mat_mat_f32.h"
#include "shaderop_getrows_f32.h"
#include "shaderop_getrows_f16.h"
#include "shaderop_getrows_q4_0.h"
#include "shaderop_getrows_q4_1.h"
#include "shaderop_getrows_q6_k.h"
#include "shaderop_rope_f16.h"
#include "shaderop_rope_f32.h"
#include "shaderop_cpy_f16_f16.h"
#include "shaderop_cpy_f16_f32.h"
#include "shaderop_cpy_f32_f16.h"
#include "shaderop_cpy_f32_f32.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <kompute/Kompute.hpp>
#include <vulkan/vulkan.hpp>

#ifdef __linux__
#include <cstdlib> // for setenv
#endif

#define QK4_0 32
#define QR4_0 2
#define QK4_1 32
#define QK_NL 16

typedef ggml_fp16_t half;

static std::string ggml_kompute_format_name(int device) {
    return "Kompute" + std::to_string(device);
}

struct ggml_kompute_context {
    int device;
    std::string name;
    std::shared_ptr<vk::DescriptorPool> pool;

    ggml_kompute_context(int device)
        : device(device), name(ggml_kompute_format_name(device)) {}
};

// FIXME: It would be good to consolidate the kompute manager and the kompute context into one object
// and consolidate the init functions and simplify object lifetime management. As it currently stands,
// we *have* to have the kompute manager no matter what for device discovery, but the kompute context
// is only created when a device is set and vulkan is explicitly turned on.
static ggml_kompute_context *s_kompute_context = nullptr;

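// Lazily (re)creates the global kp::Manager: if the previously created
// manager has lost its Vulkan instance, it is destroyed and a fresh one is
// constructed on the next access via operator().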
class kompute_manager {
    kp::Manager *s_mgr = nullptr;

public:
    kp::Manager *operator()() {
        if (s_mgr && !s_mgr->hasInstance()) {
            destroy();
        }
        if (!s_mgr) {
            s_mgr = new kp::Manager;
        }
        return s_mgr;
    }

    void destroy() {
        delete s_mgr;
        s_mgr = nullptr;
    }
};

static kompute_manager komputeManager;

struct ggml_vk_memory {
    void *data = nullptr;
    size_t size = 0;
    vk::DeviceMemory *primaryMemory = nullptr;
    vk::Buffer *primaryBuffer = nullptr;
    vk::DeviceMemory *stagingMemory = nullptr;
    vk::Buffer *stagingBuffer = nullptr;
};

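// On Linux, opt in to RADV's "sam" (Smart Access Memory / resizable BAR)
// path before the Vulkan instance is created. The last setenv argument (0)
// means a user-provided RADV_PERFTEST value is never overwritten.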
#ifdef __linux__
__attribute__((constructor))
static void enable_sam() {
    setenv("RADV_PERFTEST", "sam", false);
}
#endif

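// Reject devices that cannot run the generated shaders: the kernels depend on
// 16-bit and 8-bit storage buffer access as well as float16/int8 shader
// arithmetic, exposed through the Vulkan 1.1 and 1.2 feature structs below.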
static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) {
    vk::PhysicalDeviceFeatures availableFeatures;
    physical_device.getFeatures(&availableFeatures);

    if (!availableFeatures.shaderInt16)
        return false;

    vk::PhysicalDeviceVulkan11Features availableFeatures11;
    vk::PhysicalDeviceVulkan12Features availableFeatures12;

    availableFeatures11.pNext = &availableFeatures12;
    availableFeatures12.pNext = nullptr;

    vk::PhysicalDeviceFeatures2 features2;
    features2.pNext = &availableFeatures11;

    physical_device.getFeatures2(&features2);

    if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
        !availableFeatures11.storageBuffer16BitAccess) {
        return false;
    }

    if (!availableFeatures12.storageBuffer8BitAccess ||
        !availableFeatures12.uniformAndStorageBuffer8BitAccess ||
        !availableFeatures12.shaderFloat16 ||
        !availableFeatures12.shaderInt8) {
        return false;
    }

    return true;
}

static const char * ggml_vk_getVendorName(uint32_t vendorID) {
    switch (vendorID) {
        case 0x10DE:
            return "nvidia";
        case 0x1002:
            return "amd";
        case 0x8086:
            return "intel";
        default:
            return "unknown";
    }
}

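// Enumerates Vulkan physical devices and keeps only those with Vulkan 1.2+,
// the required features, enough device-local memory, and a subgroup size of
// at least 32, then sorts them: discrete GPUs first, integrated GPUs next,
// everything else last.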
static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t memoryRequired) {
    std::vector<ggml_vk_device> results;
    if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance())
        return results;

    std::vector<vk::PhysicalDevice> physical_devices;
    try {
        physical_devices = komputeManager()->listDevices();
    } catch (vk::SystemError & err) {
        std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n";
        return results;
    }

    uint32_t deviceCount = physical_devices.size();
    if (deviceCount == 0)
        return results;

    std::unordered_map<std::string, size_t> count_by_name;

    for (uint32_t i = 0; i < deviceCount; i++) {
        const auto & physical_device = physical_devices[i];

        VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
        VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties();
        const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion);
        const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion);
        // require Vulkan 1.2 or newer
        if (major < 1 || (major == 1 && minor < 2))
            continue;

        if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device))
            continue;

        size_t heapSize = 0;
        for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) {
            VkMemoryHeap heap = memoryProperties.memoryHeaps[j];
            if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) {
                heapSize = heap.size;
                break;
            }
        }

        if (heapSize < memoryRequired)
            continue;

        auto ext_props = physical_device.enumerateDeviceExtensionProperties();
        bool has_maintenance4 = false;

        // Check if maintenance4 is supported
        for (const auto & properties : ext_props) {
            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
                has_maintenance4 = true;
            }
        }

        vk::PhysicalDeviceSubgroupProperties subgroup_props;
        vk::PhysicalDeviceProperties2 dev_props2;
        vk::PhysicalDeviceMaintenance3Properties dev_props3;
        vk::PhysicalDeviceMaintenance4Properties dev_props4;
        dev_props2.pNext = &dev_props3;
        dev_props3.pNext = &subgroup_props;
        if (has_maintenance4) {
            subgroup_props.pNext = &dev_props4;
        }
        physical_device.getProperties2(&dev_props2);

        if (subgroup_props.subgroupSize < 32)
            continue;

        ggml_vk_device d;
        d.index = i;
        d.type = dev_props.deviceType;
        d.heapSize = heapSize;
        d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID));
        d.subgroupSize = subgroup_props.subgroupSize;
        d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment;

        if (has_maintenance4) {
            d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize);
        } else {
            d.maxAlloc = dev_props3.maxMemoryAllocationSize;
        }

        std::string name(dev_props.deviceName);
        size_t n_idx = ++count_by_name[name];
        if (n_idx > 1) {
            name += " (" + std::to_string(n_idx) + ")";
        }
        d.name = strdup(name.c_str());

        results.push_back(d);
    }

    std::stable_sort(results.begin(), results.end(),
        [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool {
            if (lhs.type != rhs.type) {
                if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true;
                if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false;

                if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true;
                if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false;
            }
            return lhs.heapSize < rhs.heapSize;
        }
    );

    return results;
}

// public API returns a C-style array
ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count) {
    auto devices = ggml_vk_available_devices_internal(memoryRequired);
    *count = devices.size();
    if (devices.empty()) {
        return nullptr;
    }

    size_t nbytes = sizeof (ggml_vk_device) * (devices.size());
    auto * arr = static_cast<ggml_vk_device *>(malloc(nbytes));
    memcpy(arr, devices.data(), nbytes);
    return arr;
}

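// The helpers below use the erase-remove idiom to drop every device whose
// vendor or name does not match the requested string.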
static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
    devices.erase(
        std::remove_if(devices.begin(), devices.end(),
            [&targetVendor](const ggml_vk_device& device) {
                return device.vendor != targetVendor;
            }),
        devices.end()
    );
}

static void ggml_vk_filterByName(std::vector<ggml_vk_device>& devices, const std::string& targetName) {
    devices.erase(
        std::remove_if(devices.begin(), devices.end(),
            [&targetName](const ggml_vk_device& device) {
                return device.name != targetName;
            }),
        devices.end()
    );
}

static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) {
    if (name.empty())
        return false;

    auto devices = ggml_vk_available_devices_internal(memoryRequired);
    if (name == "amd" || name == "nvidia" || name == "intel") {
        ggml_vk_filterByVendor(devices, name);
    } else if (name != "gpu") {
        ggml_vk_filterByName(devices, name);
    }

    if (devices.empty())
        return false;

    *device = devices.front();
    return true;
}

bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) {
    return ggml_vk_get_device(device, memoryRequired, std::string(name));
}

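// Illustrative usage sketch (not part of this file): enumerate devices and
// select one by vendor, name, or the generic "gpu".
//
//     size_t count = 0;
//     ggml_vk_device * devs = ggml_vk_available_devices(0, &count);
//     for (size_t i = 0; i < count; ++i)
//         printf("%zu: %s (%s)\n", i, devs[i].name, devs[i].vendor);
//     free(devs); // the caller owns the malloc'd array
//
//     ggml_vk_device dev;
//     if (ggml_vk_get_device(&dev, /*memoryRequired=*/0, "gpu"))
//         printf("selected: %s\n", dev.name);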
bool ggml_vk_has_vulkan() {
    return komputeManager()->hasVulkan();
}

bool ggml_vk_has_device() {
    return komputeManager()->hasDevice();
}

ggml_vk_device ggml_vk_current_device() {
    if (!komputeManager()->hasDevice())
        return ggml_vk_device();

    auto devices = ggml_vk_available_devices_internal(0);
    ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
    GGML_ASSERT(!devices.empty());
    return devices.front();
}

static
void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t size) {
    std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
        vk::DescriptorPoolSize(
          vk::DescriptorType::eStorageBuffer,
          3 * size // Descriptor count is number of possible tensors to pass into an algorithm
          )
    };

    vk::DescriptorPoolCreateInfo descriptorPoolInfo(
      vk::DescriptorPoolCreateFlags(),
      size, // Max sets
      static_cast<uint32_t>(descriptorPoolSizes.size()),
      descriptorPoolSizes.data());

    ctx->pool = std::make_shared<vk::DescriptorPool>();
    vk::Result r = komputeManager()->device()->createDescriptorPool(
      &descriptorPoolInfo, nullptr, ctx->pool.get());
    if (r != vk::Result::eSuccess)
        std::cerr << "Error allocating descriptor pool: " << vk::to_string(r) << std::endl;
}

static
void ggml_vk_free_descriptor_pool(struct ggml_kompute_context * ctx) {
    if (ctx->pool) {
        komputeManager()->device()->destroy(
          *ctx->pool,
          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
        ctx->pool = nullptr;
    }
}

static
vk::Buffer *ggml_vk_allocate_buffer(size_t size) {
    vk::BufferCreateInfo bufferCreateInfo;
    bufferCreateInfo.size = size;
    bufferCreateInfo.usage = vk::BufferUsageFlagBits::eStorageBuffer |
                             vk::BufferUsageFlagBits::eTransferSrc |
                             vk::BufferUsageFlagBits::eTransferDst;
    bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;

    vk::Buffer *vkBuffer = new vk::Buffer;
    vk::Result r = komputeManager()->device()->createBuffer(&bufferCreateInfo, nullptr, vkBuffer);
    if (r != vk::Result::eSuccess)
        std::cerr << "Error allocating buffer " << vk::to_string(r) << std::endl;
    return vkBuffer;
}

static
vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, vk::MemoryRequirements requirements, bool *isHostVisible) {

    uint32_t memoryTypeIndex = -1;
    bool memoryTypeIndexFound = false;
    vk::PhysicalDeviceMemoryProperties memoryProperties = komputeManager()->physicalDevice()->getMemoryProperties();
    for (uint32_t i = 0; i < memoryProperties.memoryTypeCount; i++) {
        const vk::MemoryType &memoryType = memoryProperties.memoryTypes[i];
        const vk::MemoryHeap &memoryHeap = memoryProperties.memoryHeaps[memoryType.heapIndex];
        if (memoryHeap.size < size) {
            continue;
        }

        if (requirements.memoryTypeBits & (1 << i)) {
            if (((memoryProperties.memoryTypes[i]).propertyFlags &
                 flags) == flags) {
                memoryTypeIndex = i;
                memoryTypeIndexFound = true;
                if (isHostVisible && (memoryProperties.memoryTypes[i].propertyFlags & vk::MemoryPropertyFlagBits::eHostVisible)) {
                    *isHostVisible = true;
                }
                break;
            }
        }
    }
    if (!memoryTypeIndexFound) {
        throw std::runtime_error(
          "Memory type index for buffer creation not found");
    }

    vk::MemoryAllocateInfo allocInfo;
    allocInfo.allocationSize = size;
    allocInfo.memoryTypeIndex = memoryTypeIndex;
    vk::DeviceMemory *vkDeviceMemory = new vk::DeviceMemory;
    vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory);
    if (r != vk::Result::eSuccess) {
        std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl;
        throw std::runtime_error("Error allocating vulkan memory.");
    }
    return vkDeviceMemory;
}

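// Rounds a tensor's byte offset down to the buffer's minimum storage-buffer
// offset alignment. Example: with a 256-byte alignment, offset 300 is floored
// to 256; ggml_vk_get_tensor() reports the 44-byte remainder separately as
// alignedOffset so the shaders can re-apply it.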
static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) {
    size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer);

    // If offset is already aligned, return it directly
    if (offset % minStorageBufferOffsetAlignment == 0) {
        return offset;
    }

    // Otherwise, return the largest multiple of minStorageBufferOffsetAlignment less than offset
    return (offset / minStorageBufferOffsetAlignment) * minStorageBufferOffsetAlignment;
}

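// Allocates the backing memory for a backend buffer: a device-local primary
// allocation, plus a host-visible (cached, coherent) staging buffer whenever
// the primary memory itself cannot be mapped, as is typical for discrete GPUs.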
static ggml_vk_memory ggml_vk_allocate(size_t size) {
    ggml_vk_memory memory;
    bool isHostVisible = false;
    {
        memory.primaryBuffer = ggml_vk_allocate_buffer(size);
        vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer);
        vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal;
        memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
        komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0);
        if (isHostVisible) {
            vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
            if (r != vk::Result::eSuccess)
                std::cerr << "Error mapping memory: " << vk::to_string(r) << std::endl;
        }
    }

    if (!isHostVisible) {
        memory.stagingBuffer = ggml_vk_allocate_buffer(size);
        vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.stagingBuffer);
        vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eHostVisible |
                                                      vk::MemoryPropertyFlagBits::eHostCoherent |
                                                      vk::MemoryPropertyFlagBits::eHostCached;
        memory.stagingMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
        komputeManager()->device()->bindBufferMemory(*memory.stagingBuffer, *memory.stagingMemory, 0);
        vk::Result r = komputeManager()->device()->mapMemory(*memory.stagingMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
        if (r != vk::Result::eSuccess)
            std::cerr << "Error mapping memory: " << vk::to_string(r) << std::endl;
    }

    memory.size = size;
    return memory;
}

static void ggml_vk_free_memory(ggml_vk_memory &memory)
{
    komputeManager()->device()->destroy(
      *memory.primaryBuffer,
      (vk::Optional<const vk::AllocationCallbacks>)nullptr);
    if (memory.stagingBuffer) {
        komputeManager()->device()->destroy(
          *memory.stagingBuffer,
          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
    }
    komputeManager()->device()->freeMemory(
      *memory.primaryMemory,
      (vk::Optional<const vk::AllocationCallbacks>)nullptr);
    if (memory.stagingMemory) {
        komputeManager()->device()->freeMemory(
          *memory.stagingMemory,
          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
    }
}

static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft);

static
ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & offset) {
    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;

    // compatibility with ggml-backend
    GGML_ASSERT(buffer && buffer->buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name);

    ggml_vk_memory * buf_ctx = static_cast<ggml_vk_memory *>(buffer->context);

    const intptr_t ioffs = intptr_t(t->data) - intptr_t(buf_ctx->data);

    GGML_ASSERT(ioffs >= 0 && ioffs + int64_t(ggml_nbytes(t)) <= int64_t(buffer->size));

    offset = uint64_t(ioffs);
    return buf_ctx;
}

static
const std::shared_ptr<kp::Tensor> ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) {
    uint64_t originalOffset = 0;
    auto * res = ggml_vk_find_tensor(t, originalOffset);
    if (!res) {
        static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
        return nullTensor;
    }

    // Create a tensor whose memory will be composed of our buffers at the correct offset
    const size_t nelements = ggml_nelements(t);
    size_t nbytes = ggml_nbytes(t);

    size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset);
    if (alignedOffset) {
        *alignedOffset = originalOffset - vulkanOffset;
        nbytes += *alignedOffset;
    }

    return komputeManager()->tensor(
        t->data,
        nelements,
        nbytes, kp::Tensor::TensorDataTypes::eFloat,
        res->primaryMemory, res->primaryBuffer,
        res->stagingMemory, res->stagingBuffer,
        vulkanOffset);
}

static std::vector<uint32_t> getSpirvShader(const unsigned char* rawData, size_t size) {
    if (size % sizeof(uint32_t) != 0) {
        throw std::runtime_error("Invalid size: must be divisible by sizeof(uint32_t)");
    }

    const uint32_t* data_ptr = reinterpret_cast<const uint32_t*>(rawData);
    size_t count = size / sizeof(uint32_t);
    return std::vector<uint32_t>(data_ptr, data_ptr + count);
}

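// Push-constant offsets are passed in elements rather than bytes; safe_divide
// converts a byte offset to an element index and asserts that the offset is
// in fact a multiple of the element size (e.g. safe_divide(off, 4) for f32,
// safe_divide(off, 2) for f16).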
inline static
uint32_t safe_divide(uint32_t a, uint32_t b) {
    if (b <= 1) {
        return a;
    }
    if ((a % b) != 0) {
        fprintf(stderr, "((%u %% %u) == %u) != 0\n", a, b, a % b);
        GGML_ASSERT(!"safe_divide result would've had remainder");
    }
    return a / b;
}

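// All dispatch helpers below share one pattern: compile the SPIR-V into an
// algorithm once and cache it in the manager under a stable name (usually
// __func__); on subsequent calls, rebind the tensors, workgroup size, and
// push constants on the cached algorithm before recording the dispatch.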
static void ggml_vk_add(
    kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& inA,
    const std::shared_ptr<kp::Tensor>& inB,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
    int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
    int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
    int32_t ne0,
    int32_t nb0,  int32_t nb1,  int32_t nb2,  int32_t nb3
) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv,
        kp::shader_data::op_add_comp_spv_len);

    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00;
        int32_t nb00, nb01, nb02, nb03;
        int32_t ne10, ne11, ne12, ne13;
        int32_t nb10, nb11, nb12, nb13;
        int32_t ne0;
        int32_t nb0, nb1, nb2, nb3;
    } const pushConsts {
        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00,
        nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
        nb10, nb11, nb12, nb13,
        ne0,
        nb0, nb1, nb2, nb3
    };

    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(__func__)) {
        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(__func__);
        s_algo->setTensors({inA, inB, out});
        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

static void ggml_vk_addrow(kp::Sequence& seq,
                 const std::shared_ptr<kp::Tensor>& inA,
                 const std::shared_ptr<kp::Tensor>& inB,
                 const std::shared_ptr<kp::Tensor>& out,
                 uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
                 uint32_t size, uint32_t row = 0) {

    const static auto spirv = getSpirvShader(kp::shader_data::op_addrow_comp_spv,
        kp::shader_data::op_addrow_comp_spv_len);

    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        uint32_t row;
    } const pushConsts {
        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        row
    };

    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(__func__))
        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
    else {
        s_algo = komputeManager()->getAlgorithm(__func__);
        s_algo->setTensors({inA, inB, out});
        s_algo->setWorkgroup({size});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

static void ggml_vk_mul(
    kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& inA,
    const std::shared_ptr<kp::Tensor>& inB,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
    int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
    int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
    int32_t ne0,
    int32_t nb0,  int32_t nb1,  int32_t nb2,  int32_t nb3
) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv,
        kp::shader_data::op_mul_comp_spv_len);

    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00;
        int32_t nb00, nb01, nb02, nb03;
        int32_t ne10, ne11, ne12, ne13;
        int32_t nb10, nb11, nb12, nb13;
        int32_t ne0;
        int32_t nb0, nb1, nb2, nb3;
    } const pushConsts {
        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00,
        nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
        nb10, nb11, nb12, nb13,
        ne0,
        nb0, nb1, nb2, nb3
    };

    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(__func__)) {
        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(__func__);
        s_algo->setTensors({inA, inB, out});
        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

static void ggml_vk_scale(kp::Sequence& seq,
                   const std::shared_ptr<kp::Tensor>& in,
                   const std::shared_ptr<kp::Tensor>& out,
                   uint32_t inOff, uint32_t outOff,
                   uint32_t size, float scale) {
    const static auto spirv_1 = getSpirvShader(
        kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
    );
    const static auto spirv_8 = getSpirvShader(
        kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
    );

    struct PushConstants {
        uint32_t inOff, outOff;
        float scale;
    } const pushConsts {
        safe_divide(inOff, 4), safe_divide(outOff, 4),
        scale
    };

    const auto * spirv = &spirv_1;
    std::string name(__func__);
    if (size % 8 == 0) {
        size /= 8;
        name += "_8";
        spirv = &spirv_8;
    }

    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(name)) {
        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(name);
        s_algo->setTensors({in, out});
        s_algo->setWorkgroup({size});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

static void ggml_vk_xxlu(
    const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& in,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inOff, uint32_t outOff,
    uint32_t size
) {
    struct PushConstants {
        uint32_t inOff, outOff;
    } const pushConsts {
        safe_divide(inOff, 4), safe_divide(outOff, 4),
    };

    auto name = std::string(__func__) + "_" + suffix;
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(name)) {
        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(name);
        s_algo->setTensors({in, out});
        s_algo->setWorkgroup({size});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

template <typename... Args>
static void ggml_vk_silu(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv,
        kp::shader_data::op_silu_comp_spv_len);

    ggml_vk_xxlu(spirv, "silu", std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_relu(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv,
        kp::shader_data::op_relu_comp_spv_len);

    ggml_vk_xxlu(spirv, "relu", std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_gelu(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv,
        kp::shader_data::op_gelu_comp_spv_len);

    ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
}

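// inB is the optional mask tensor. When no mask is given, inA is bound in
// the mask's descriptor slot and the `mask` push constant tells the shader
// to ignore it.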
static void ggml_vk_soft_max(
    kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& inA,
    const std::shared_ptr<kp::Tensor>& inB,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
    int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
    float scale
) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
        kp::shader_data::op_softmax_comp_spv_len);

    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00, ne01, ne02;
        float scale;
        int32_t mask;
    } pushConsts {
        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00, ne01, ne02,
        scale,
        bool(inB)
    };

    auto & inB_ = inB ? inB : inA;

    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(__func__)) {
        // FIXME: The softmax kernel needs to be fixed to use the subgroup size, which can vary by device
        const uint32_t local_x = 32;
        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(__func__);
        s_algo->setTensors({inA, inB_, out});
        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

static void ggml_vk_norm_(
    const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& in,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inOff, uint32_t outOff,
    int32_t ne00, int32_t nb01,
    int32_t nrows, float epsilon
) {
    GGML_ASSERT(nb01%sizeof(float) == 0);
    GGML_ASSERT(ne00%sizeof(float) == 0);

    struct PushConstants {
        uint32_t inOff, outOff;
        uint32_t ne00, nb01;
        float eps;
    } pushConsts {
        safe_divide(inOff, 4), safe_divide(outOff, 4),
        (uint32_t)ne00, (uint32_t)nb01, epsilon
    };

    auto name = std::string(__func__) + "_" + suffix;
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(name)) {
        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(name);
        s_algo->setTensors({in, out});
        s_algo->setWorkgroup({(uint32_t)nrows});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

template <typename... Args>
static void ggml_vk_norm(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv,
        kp::shader_data::op_norm_comp_spv_len);

    ggml_vk_norm_(spirv, "norm", std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_rms_norm(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv,
        kp::shader_data::op_rmsnorm_comp_spv_len);

    ggml_vk_norm_(spirv, "rms", std::forward<Args>(args)...);
}

static void ggml_vk_diag_mask_inf(kp::Sequence& seq,
                           const std::shared_ptr<kp::Tensor>& in,
                           const std::shared_ptr<kp::Tensor>& out,
                           uint32_t inOff, uint32_t outOff,
                           uint32_t n_past,
                           int32_t ne00, int32_t ne01, int32_t ne02) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_diagmask_comp_spv,
        kp::shader_data::op_diagmask_comp_spv_len);

    struct PushConstants {
        uint32_t inOff, outOff;
        uint32_t n_past;
        int32_t ne00, ne01;
    } pushConsts {
        safe_divide(inOff, 4), safe_divide(outOff, 4),
        n_past,
        ne00, ne01
    };

    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(__func__))
        s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts});
    else {
        s_algo = komputeManager()->getAlgorithm(__func__);
        s_algo->setTensors({in, out});
        s_algo->setWorkgroup({unsigned(ne00), unsigned(ne01), unsigned(ne02)});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

static void ggml_vk_mul_mat_f16(
    kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& inA,
    const std::shared_ptr<kp::Tensor>& inB,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
    int32_t ne00, int32_t ne01, int32_t ne02,
    uint32_t nb00, uint32_t nb01, uint32_t nb02,
    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
    uint32_t nb10, uint32_t nb11, uint32_t nb12,
    int32_t ne0, int32_t ne1,
    uint32_t r2, uint32_t r3
) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_f16_comp_spv,
        kp::shader_data::op_mul_mat_f16_comp_spv_len);

    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00, ne01, ne02;
        uint32_t nb00, nb01, nb02;
        int32_t ne10, ne11, ne12;
        uint32_t nb10, nb11, nb12;
        int32_t ne0, ne1;
        uint32_t r2, r3;
    } pushConsts {
        safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00, ne01, ne02,
        nb00, nb01, nb02,
        ne10, ne11, ne12,
        nb10, nb11, nb12,
        ne0, ne1,
        r2, r3
    };

    const unsigned ny = unsigned((ne11 + 4 - 1)/4);

    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(__func__)) {
        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), ny, unsigned(ne12*ne13)}, {local_x}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(__func__);
        s_algo->setTensors({inA, inB, out});
        s_algo->setWorkgroup({unsigned(ne01), ny, unsigned(ne12*ne13)});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
                         const std::shared_ptr<kp::Tensor>& inA,
                         const std::shared_ptr<kp::Tensor>& inB,
                         const std::shared_ptr<kp::Tensor>& out,
                         uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
                         int32_t ne00, int32_t ne01, int32_t ne02,
                         uint32_t nb01, uint32_t nb02,
                         int32_t ne11, int32_t ne12,
                         uint32_t nb11, uint32_t nb12,
                         uint32_t nb1, uint32_t nb2) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_mat_f32_comp_spv,
        kp::shader_data::op_mul_mat_mat_f32_comp_spv_len);

    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00, ne01, ne02, ne11, ne12;
        uint32_t nb01, nb02;
        uint32_t nb11, nb12;
        uint32_t nb1, nb2;
    } pushConsts {
        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00, ne01, ne02, ne11, ne12,
        nb01, nb02, nb11, nb12,
        nb1, nb2
    };

    const uint32_t local_x = ggml_vk_current_device().subgroupSize;
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(__func__)) {
        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(
            __func__, s_kompute_context->pool.get(),
            {inA, inB, out}, spirv,
            {unsigned(ne01), unsigned(ne11), unsigned(std::max(ne12, ne02))},
            {local_x},
            {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(__func__);
        s_algo->setTensors({inA, inB, out});
        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne11), unsigned(std::max(ne12, ne02))});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

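// Shared driver for the quantized mul_mat kernels. Each workgroup produces
// eight output rows, hence the (ne01 + 7)/8 grid size; block_size scales
// inAOff and is 1 for the quantized types because their blocks may begin at
// unaligned byte offsets.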
static void ggml_vk_mul_mat_impl(
    const std::vector<uint32_t>& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& inA,
    const std::shared_ptr<kp::Tensor>& inB,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
    int32_t ne00, int32_t ne01, int32_t ne02,
    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
    int32_t ne0, int32_t ne1,
    uint32_t r2, uint32_t r3
) {
    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00, ne01, ne02;
        int32_t ne10, ne12;
        int32_t ne0, ne1;
        uint32_t r2, r3;
    } pushConsts {
        safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3
    };

    auto name = std::string(__func__) + "_" + suffix;
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(name)) {
        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(name);
        s_algo->setTensors({inA, inB, out});
        s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

template <typename... Args>
static void ggml_vk_mul_mat_q4_0(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv,
        kp::shader_data::op_mul_mat_q4_0_comp_spv_len);

    ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_mul_mat_q4_1(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv,
        kp::shader_data::op_mul_mat_q4_1_comp_spv_len);

    ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_mul_mat_q8_0(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q8_0_comp_spv,
        kp::shader_data::op_mul_mat_q8_0_comp_spv_len);

    ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
}

static void ggml_vk_mul_mat_q6_k(
    kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& inA,
    const std::shared_ptr<kp::Tensor>& inB,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
    int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
    int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
        kp::shader_data::op_mul_mat_q6_k_comp_spv_len);

    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00, ne10, ne0, ne1, ne01, gqa;
    } pushConsts {
        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00, ne10, ne0, ne1, ne01, ne12/ne02
    };

    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(__func__)) {
        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(__func__);
        s_algo->setTensors({inA, inB, out});
        s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

static void ggml_vk_get_rows(
    const std::vector<uint32_t>& spirv,
    const char * suffix,
    unsigned element_size, unsigned qk,
    kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& inA,
    const std::shared_ptr<kp::Tensor>& inB,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
    int32_t ne00, int32_t nb01, int32_t nb1,
    uint32_t size
) {
    GGML_ASSERT(nb01%element_size == 0);
    GGML_ASSERT(nb1%sizeof(float) == 0);
    if (qk) GGML_ASSERT(ne00%qk == 0);

    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00, nb01, nb1;
    } pushConsts {
        safe_divide(inAOff, element_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00, nb01, nb1
    };

    auto name = std::string(__func__) + "_" + suffix;
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(name)) {
        s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(name);
        s_algo->setTensors({inA, inB, out});
        s_algo->setWorkgroup({size});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
    seq.record<kp::OpAlgoDispatch>(s_algo);
}

template <typename... Args>
static void ggml_vk_get_rows_f32(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
        kp::shader_data::op_getrows_f32_comp_spv_len);

    ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_get_rows_f16(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
        kp::shader_data::op_getrows_f16_comp_spv_len);

    ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_get_rows_q4_0(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
        kp::shader_data::op_getrows_q4_0_comp_spv_len);

    ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_get_rows_q4_1(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
        kp::shader_data::op_getrows_q4_1_comp_spv_len);

    ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_get_rows_q6_k(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
        kp::shader_data::op_getrows_q6_k_comp_spv_len);
    ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
}

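// RoPE: push-constant offsets for src0/dst are scaled by the element size of
// the tensor type (2 bytes for f16, 4 for f32), and the matching f16 or f32
// shader variant is chosen by name.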
1189  static void ggml_vk_rope(
1190      kp::Sequence& seq,
1191      const std::shared_ptr<kp::Tensor>& inA,
1192      const std::shared_ptr<kp::Tensor>& inB,
1193      const std::shared_ptr<kp::Tensor>& out,
1194      uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1195      ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
1196      float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
1197      int32_t ne01, int32_t ne02, int32_t ne03,
1198      uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
1199      int32_t ne0,
1200      uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
1201  ) {
1202      GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
1203  
1204      static const auto spirv_f16 = getSpirvShader(
1205          kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
1206      );
1207      static const auto spirv_f32 = getSpirvShader(
1208          kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
1209      );
1210  
1211      int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
1212  
1213      GGML_ASSERT(nb03 % type_size == 0);
1214      GGML_ASSERT(nb02 % type_size == 0);
1215      GGML_ASSERT(nb01 % type_size == 0);
1216      GGML_ASSERT(nb00 % type_size == 0);
1217      GGML_ASSERT(nb3  % type_size == 0);
1218      GGML_ASSERT(nb2  % type_size == 0);
1219      GGML_ASSERT(nb1  % type_size == 0);
1220      GGML_ASSERT(nb0  % type_size == 0);
1221  
1222      struct PushConstants {
1223          uint32_t inAOff, inBOff, outOff;
1224          int32_t n_dims, mode, n_ctx_orig;
1225          float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1226          uint32_t nb00, nb01, nb02, nb03;
1227          int32_t ne0;
1228          uint32_t nb0, nb1, nb2, nb3;
1229      } pushConsts {
1230          safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
1231          n_dims, mode, n_ctx_orig,
1232          freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1233          nb00, nb01, nb02, nb03,
1234          ne0,
1235          nb0, nb1, nb2, nb3
1236      };
1237  
1238      auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
1239      std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1240      if (!komputeManager()->hasAlgorithm(name)) {
1241          s_algo = komputeManager()->algorithm<float, PushConstants>(
1242              name, s_kompute_context->pool.get(), {inA, inB, out},
1243              src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
1244              {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
1245          );
1246      } else {
1247          s_algo = komputeManager()->getAlgorithm(name);
1248          s_algo->setTensors({inA, inB, out});
1249          s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
1250          s_algo->setPushConstants<PushConstants>({pushConsts});
1251          s_algo->updateDescriptors(s_kompute_context->pool.get());
1252      }
1253      seq.record<kp::OpAlgoDispatch>(s_algo);
1254  }
1255  
1256  static void ggml_vk_cpy(
1257      const std::vector<uint32_t>& spirv,
1258      uint32_t in_element_size, uint32_t out_element_size,
1259      kp::Sequence& seq,
1260      const std::shared_ptr<kp::Tensor>& in,
1261      const std::shared_ptr<kp::Tensor>& out,
1262      uint32_t inOff, uint32_t outOff,
1263      int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
1264      uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
1265      int32_t ne0, int32_t ne1, int32_t ne2,
1266      uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
1267  ) {
1268      struct PushConstants {
1269          uint32_t inOff, outOff;
1270          int32_t ne00, ne01, ne02;
1271          uint32_t nb00, nb01, nb02, nb03;
1272          int32_t ne0, ne1, ne2;
1273          uint32_t nb0, nb1, nb2, nb3;
1274      } pushConsts {
1275          safe_divide(inOff, in_element_size), safe_divide(outOff, out_element_size),
1276          ne00, ne01, ne02,
1277          nb00, nb01, nb02, nb03,
1278          ne0, ne1, ne2,
1279          nb0, nb1, nb2, nb3
1280      };
1281  
1282      std::string name = std::string(__func__)
1283                         + "_i_" + std::to_string(in_element_size)
1284                         + "_o_" + std::to_string(out_element_size);
1285      std::shared_ptr<kp::Algorithm> s_algo = nullptr;
1286      if (!komputeManager()->hasAlgorithm(name))
1287          s_algo = komputeManager()->algorithm<float, PushConstants>(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
1288      else {
1289          s_algo = komputeManager()->getAlgorithm(name);
1290          s_algo->setTensors({in, out});
1291          s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
1292          s_algo->setPushConstants<PushConstants>({pushConsts});
1293          s_algo->updateDescriptors(s_kompute_context->pool.get());
1294      }
1295      seq.record<kp::OpAlgoDispatch>(s_algo);
1296  }
1297  
template <typename... Args>
static void ggml_vk_cpy_f32_f16(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv,
        kp::shader_data::op_cpy_f32_f16_comp_spv_len);
    ggml_vk_cpy(spirv, 4, 2, std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_cpy_f32_f32(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv,
        kp::shader_data::op_cpy_f32_f32_comp_spv_len);
    ggml_vk_cpy(spirv, 4, 4, std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_cpy_f16_f16(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv,
        kp::shader_data::op_cpy_f16_f16_comp_spv_len);
    ggml_vk_cpy(spirv, 2, 2, std::forward<Args>(args)...);
}

template <typename... Args>
static void ggml_vk_cpy_f16_f32(Args&&... args) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv,
        kp::shader_data::op_cpy_f16_f32_comp_spv_len);
    ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
}

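// Reports whether the Kompute backend can execute a given node. This is checked
// both through the backend's supports_op callback and again at graph-compute
// time, where an unsupported op is a hard failure.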
static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
    switch (op->type) {
        case GGML_TYPE_F16:
        case GGML_TYPE_F32:
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
            break;
        default:
            return false;
    }

    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_SILU:
                    return ggml_is_contiguous(op->src[0]);
                default:
                    ;
            }
            break;
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_PERMUTE:
        case GGML_OP_ADD:
        case GGML_OP_MUL:
        case GGML_OP_SCALE:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_RMS_NORM:
        case GGML_OP_NORM:
        case GGML_OP_ROPE:
            return true;
        case GGML_OP_DUP:
        case GGML_OP_CPY:
        case GGML_OP_CONT:
            switch (op->src[0]->type) {
                case GGML_TYPE_F32:
                case GGML_TYPE_F16:
                    break;
                default:
                    return false;
            }
            switch (op->type) {
                case GGML_TYPE_F32:
                case GGML_TYPE_F16:
                    break;
                default:
                    return false;
            }
            return true;
        case GGML_OP_DIAG_MASK_INF:
            return op->ne[3] == 1;
        case GGML_OP_GET_ROWS:
            switch (op->src[0]->type) {
                case GGML_TYPE_F32:
                case GGML_TYPE_F16:
                case GGML_TYPE_Q4_0:
                case GGML_TYPE_Q4_1:
                case GGML_TYPE_Q6_K:
                    return op->ne[2] == 1 && op->ne[3] == 1;
                default:
                    ;
            }
            return false;
        case GGML_OP_MUL_MAT:
            if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1]))
                return false;

            switch (op->src[0]->type) {
                case GGML_TYPE_F32:
                case GGML_TYPE_Q6_K:
                    return op->ne[3] == 1;
                case GGML_TYPE_F16:
                case GGML_TYPE_Q8_0:
                case GGML_TYPE_Q4_0:
                case GGML_TYPE_Q4_1:
                    return true;
                default:
                    ;
            }
        default:
            ;
    }
    return false;
}

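// Records the graph's nodes into a fixed number (n_seq) of Kompute sequences
// (command buffers), launches them asynchronously, then waits for all of them
// to finish.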
static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
    const int n_seq = 8;

    // FIXME: Figure out if we can somehow optimize the size of the pool... right now we're setting
    // it to the size of the graph, but I think it can be made smaller?
    ggml_vk_allocate_descriptor_pool(ctx, gf->n_nodes);

    std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);

    for (auto& sequence : sequences) {
        sequence = komputeManager()->sequence();
    }
    for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) {
        const int n_nodes_per_seq = (gf->n_nodes + n_seq - 1) / n_seq;

        auto& seq = *sequences[seq_idx];

        const int node_start = (seq_idx + 0) * n_nodes_per_seq;
        const int node_end   = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes);
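
        // The nodes are split into n_seq contiguous chunks of ceil(n_nodes / n_seq).
        // For example, with 100 nodes and n_seq == 8, n_nodes_per_seq is 13, so
        // sequences 0..6 take nodes [0,13), [13,26), ... and the last sequence
        // takes the remaining nodes [91,100).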

        bool any_commands_recorded = false;

        for (int i = node_start; i < node_end; ++i) {
            struct ggml_tensor * src0 = gf->nodes[i]->src[0];
            struct ggml_tensor * src1 = gf->nodes[i]->src[1];
            struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
            struct ggml_tensor * dst = gf->nodes[i];
            GGML_ASSERT(dst->data != nullptr);

            if (ggml_is_empty(dst)) {
                continue;
            }

            switch (dst->op) {
                case GGML_OP_NONE:
                case GGML_OP_RESHAPE:
                case GGML_OP_VIEW:
                case GGML_OP_TRANSPOSE:
                case GGML_OP_PERMUTE:
                    continue; // noop -> next node
                default:
                    break;
            }

            any_commands_recorded = true;

            if (!ggml_vk_supports_op(dst)) {
                fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
                GGML_ASSERT(!"unsupported op");
            }

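            // Local copies of the operand shapes (ne..) and byte strides (nb..),
            // zeroed when the corresponding operand is absent.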
            const int32_t ne00 = src0 ? src0->ne[0] : 0;
            const int32_t ne01 = src0 ? src0->ne[1] : 0;
            const int32_t ne02 = src0 ? src0->ne[2] : 0;
            const int32_t ne03 = src0 ? src0->ne[3] : 0;

            const uint32_t nb00 = src0 ? src0->nb[0] : 0;
            const uint32_t nb01 = src0 ? src0->nb[1] : 0;
            const uint32_t nb02 = src0 ? src0->nb[2] : 0;
            const uint32_t nb03 = src0 ? src0->nb[3] : 0;

            const int32_t ne10 = src1 ? src1->ne[0] : 0;
            const int32_t ne11 = src1 ? src1->ne[1] : 0;
            const int32_t ne12 = src1 ? src1->ne[2] : 0;
            const int32_t ne13 = src1 ? src1->ne[3] : 0;

            const uint32_t nb10 = src1 ? src1->nb[0] : 0;
            const uint32_t nb11 = src1 ? src1->nb[1] : 0;
            const uint32_t nb12 = src1 ? src1->nb[2] : 0;
            const uint32_t nb13 = src1 ? src1->nb[3] : 0;

            const int32_t ne0 = dst ? dst->ne[0] : 0;
            const int32_t ne1 = dst ? dst->ne[1] : 0;
            const int32_t ne2 = dst ? dst->ne[2] : 0;
//            const int32_t ne3 = dst ? dst->ne[3] : 0;

            const uint32_t nb0 = dst ? dst->nb[0] : 0;
            const uint32_t nb1 = dst ? dst->nb[1] : 0;
            const uint32_t nb2 = dst ? dst->nb[2] : 0;
            const uint32_t nb3 = dst ? dst->nb[3] : 0;

            const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
            const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
            const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;

            const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
            uint32_t off_src0 = 0;
            uint32_t off_src1 = 0;
            uint32_t off_dst  = 0;
            const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
            const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
            const std::shared_ptr<kp::Tensor>& id_dst  = dst  ? ggml_vk_get_tensor(dst,  &off_dst)  : nullTensor;

            switch (dst->op) {
                case GGML_OP_ADD:
                    {
                        if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
                            // src1 is a row
                            ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
                        } else {
                            ggml_vk_add(
                                seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
                                ne00, ne01, ne02, ne03,
                                nb00, nb01, nb02, nb03,
                                ne10, ne11, ne12, ne13,
                                nb10, nb11, nb12, nb13,
                                ne0,
                                nb0, nb1, nb2, nb3
                            );
                        }
                    } break;
                case GGML_OP_MUL:
                    {
                        ggml_vk_mul(
                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
                            ne00, ne01, ne02, ne03,
                            nb00, nb01, nb02, nb03,
                            ne10, ne11, ne12, ne13,
                            nb10, nb11, nb12, nb13,
                            ne0,
                            nb0, nb1, nb2, nb3
                        );
                    } break;
                case GGML_OP_SCALE:
                    {
                        float scale; memcpy(&scale, dst->op_params, sizeof(float));

                        ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
                    } break;
                case GGML_OP_UNARY:
                    {
                        int64_t n = ggml_nelements(dst);
                        GGML_ASSERT(n % 4 == 0);
                        switch (ggml_get_unary_op(gf->nodes[i])) {
                            case GGML_UNARY_OP_SILU:
                                {
                                    ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
                                } break;
                            case GGML_UNARY_OP_RELU:
                                {
                                    ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
                                } break;
                            case GGML_UNARY_OP_GELU:
                                {
                                    GGML_ASSERT(n % 8 == 0);
                                    ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
                                } break;
                            default:
                                {
                                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                                    GGML_ASSERT(false);
                                }
                        }
                    } break;
                case GGML_OP_SOFT_MAX:
                    {
                        float scale;
                        float max_bias;

                        memcpy(&scale,    (float *)dst->op_params + 0, sizeof(float));
                        memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));

#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
                        GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);

#pragma message("TODO: add ALiBi support")
#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/7192")
                        GGML_ASSERT(max_bias == 0.0f);

                        ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                    } break;
                case GGML_OP_DIAG_MASK_INF:
                    {
                        const int n_past = ((int32_t *)(dst->op_params))[0];
                        ggml_vk_diag_mask_inf(seq, id_src0, id_dst, off_src0, off_dst, n_past, ne00, ne01, ne02);
                    } break;
                case GGML_OP_NORM:
                    {
                        float eps;
                        memcpy(&eps, dst->op_params, sizeof(float));
                        ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
                    } break;
                case GGML_OP_RMS_NORM:
                    {
                        GGML_ASSERT(ne00 % 4 == 0);

                        float eps;
                        memcpy(&eps, dst->op_params, sizeof(float));
                        ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
                    } break;
                case GGML_OP_MUL_MAT:
                    {
                        GGML_ASSERT(ne00 == ne10);

                        GGML_ASSERT(ne12 % ne02 == 0);
                        GGML_ASSERT(ne13 % ne03 == 0);

                        const uint32_t r2 = ne12/ne02;
                        const uint32_t r3 = ne13/ne03;

                        if (src1t != GGML_TYPE_F32) {
                            fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
                            goto not_implemented;
                        }

                        if (ggml_is_transposed(src0) ||
                            ggml_is_transposed(src1)) {
                            fprintf(stderr, "%s: %s: matmul on transposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
                            goto not_implemented;
                        }

                        switch (src0t) {
                            case GGML_TYPE_F32:
                                ggml_vk_mul_mat_mat_f32(
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
                                    ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2
                                );
                                break;
                            case GGML_TYPE_F16:
                                ggml_vk_mul_mat_f16(
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
                                    ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
                                    ne0, ne1, r2, r3
                                );
                                break;
                            case GGML_TYPE_Q8_0:
                                ggml_vk_mul_mat_q8_0(
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
                                );
                                break;
                            case GGML_TYPE_Q4_0:
                                ggml_vk_mul_mat_q4_0(
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
                                );
                                break;
                            case GGML_TYPE_Q4_1:
                                ggml_vk_mul_mat_q4_1(
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
                                );
                                break;
                            case GGML_TYPE_Q6_K:
                                ggml_vk_mul_mat_q6_k(
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
                                    ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
                                );
                                break;
                            default: {
                                fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
                                goto not_implemented;
                            }
                        }

                    } break;
                case GGML_OP_GET_ROWS:
                    {
                        if (src0t == GGML_TYPE_F32) {
                            ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
                        } else if (src0t == GGML_TYPE_F16) {
                            ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
                        } else if (src0t == GGML_TYPE_Q4_0) {
                            ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
                        } else if (src0t == GGML_TYPE_Q4_1) {
                            ggml_vk_get_rows_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
                        } else if (src0t == GGML_TYPE_Q6_K) {
                            ggml_vk_get_rows_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
                        } else {
                            fprintf(stderr, "%s: %s: Unsupported quantization: %u\n", __func__, ggml_op_name(dst->op), src0t);
                            goto not_implemented;
                        }
                    } break;
                case GGML_OP_ROPE:
                    {
#pragma message("TODO: implement phi3 frequency factors support")
#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
                        GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");

#pragma message("TODO: update rope NORM mode to match NEOX mode")
#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")

                        GGML_ASSERT(ne10 == ne02);
                        GGML_ASSERT(src0t == dstt);
                        // const int n_past = ((int32_t *) dst->op_params)[0];
                        const int n_dims     = ((int32_t *) dst->op_params)[1];
                        const int mode       = ((int32_t *) dst->op_params)[2];
                        // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];

                        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                        memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
                        memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
                        memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
                        memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
                        memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
                        memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
                        ggml_vk_rope(
                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
                            freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
                            ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
                        );
                    } break;
                case GGML_OP_DUP:
                case GGML_OP_CPY:
                case GGML_OP_CONT:
                    {
                        switch (src0t) {
                            case GGML_TYPE_F32:
                                {
                                    switch (dstt) {
                                        case GGML_TYPE_F16: ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
                                        case GGML_TYPE_F32: ggml_vk_cpy_f32_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
                                        default: goto not_implemented;
                                    }
                                } break;
                            case GGML_TYPE_F16:
                                {
                                    switch (dstt) {
                                        case GGML_TYPE_F16: ggml_vk_cpy_f16_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
                                        case GGML_TYPE_F32: ggml_vk_cpy_f16_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
                                        default: goto not_implemented;
                                    }
                                } break;
                            default: goto not_implemented;
                        }
                    } break;
                default: goto not_implemented;
            }
            continue;
            not_implemented: {}
            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
            //GGML_ASSERT(false);
        }

        // Evaluate sequence
        if (any_commands_recorded) {
            seq.evalAsync();
        }
    }

    // Wait for all sequences to finish
    for (auto& sequence : sequences) {
        if (sequence->isRunning())
            sequence->evalAwait();
    }

    ggml_vk_free_descriptor_pool(ctx);
}

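// Tell Kompute how to interpret the element types this backend instantiates
// TensorT with. half is reported as eFloat; since the backend only moves raw
// bytes through these tensors, the declared element type appears to serve only
// as interface bookkeeping.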
template<>
kp::Tensor::TensorDataTypes
kp::TensorT<half>::dataType()
{
    return TensorDataTypes::eFloat;
}

template<>
kp::Tensor::TensorDataTypes
kp::TensorT<uint8_t>::dataType()
{
    return TensorDataTypes::eUnsignedInt;
}

////////////////////////////////////////////////////////////////////////////////

// backend interface

struct ggml_backend_kompute_buffer_type_context {
    int         device;
    int         device_ref = 0;
    uint64_t    buffer_alignment;
    uint64_t    max_alloc;
    std::string name;

    ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc)
        : device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {}
};

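// Reference counting of the Vulkan device: the first buffer-type ref initializes
// the device with the extensions the shaders rely on, and the last unref tears
// the Kompute manager down.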
static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) {
    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);

    if (!ctx->device_ref) {
        komputeManager()->initializeDevice(
            ctx->device, {}, {
                "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage",
                "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info"
            }
        );
    }

    assert(ggml_vk_has_device());
    ctx->device_ref++;
}

static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);

    assert(ctx->device_ref > 0);

    ctx->device_ref--;

    if (!ctx->device_ref) {
        komputeManager.destroy();
    }
}

static const char * ggml_backend_kompute_buffer_get_name(ggml_backend_buffer_t buffer) {
    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buffer->buft->context);
    return ctx->name.c_str();
}

static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    auto * memory = (ggml_vk_memory *)buffer->context;
    if (ggml_vk_has_device()) {
        ggml_vk_free_memory(*memory);
    }
    delete memory;
}

static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
    return ((ggml_vk_memory *)buffer->context)->data;
}

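// Tensor I/O goes through the host-mapped allocation: set_tensor writes into the
// mapped pointer and then syncs the Kompute tensor to the device, while
// get_tensor syncs device memory back to the host before reading.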
static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    GGML_UNUSED(buffer);

    const auto res = ggml_vk_get_tensor(tensor);
    GGML_ASSERT(res);

    memcpy((char *)tensor->data + offset, data, size);

    komputeManager()->sequence()->eval<kp::OpTensorSyncDevice>({res});
}

static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    GGML_UNUSED(buffer);

    const auto res = ggml_vk_get_tensor(tensor);
    GGML_ASSERT(res);

    komputeManager()->sequence()->eval<kp::OpTensorSyncLocal>({res});

    memcpy(data, (const char *)tensor->data + offset, size);
}

static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    auto * memory = (ggml_vk_memory *)buffer->context;
    memset(memory->data, value, buffer->size);

    if (memory->stagingBuffer)
        komputeManager()->sequence()->eval<kp::OpBufferSyncDevice>(memory->primaryBuffer, memory->stagingBuffer, memory->size);
}

static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
    /* .get_name        = */ ggml_backend_kompute_buffer_get_name,
    /* .free_buffer     = */ ggml_backend_kompute_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_kompute_buffer_get_base,
    /* .init_tensor     = */ NULL,
    /* .set_tensor      = */ ggml_backend_kompute_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_kompute_buffer_get_tensor,
    /* .cpy_tensor      = */ NULL,
    /* .clear           = */ ggml_backend_kompute_buffer_clear,
    /* .reset           = */ NULL,
};

// default buffer type

static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
    return ctx->name.c_str();
}

static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_kompute_device_ref(buft);
    auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size));
    return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size);
}

static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
    return ctx->buffer_alignment;
}

static size_t ggml_backend_kompute_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
    return ctx->max_alloc;
}

static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_kompute_buffer_type_get_name,
    /* .alloc_buffer     = */ ggml_backend_kompute_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_kompute_buffer_type_get_alignment,
    /* .get_max_size     = */ ggml_backend_kompute_buffer_type_get_max_size,
    /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
    /* .is_host          = */ NULL,
};

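// The buffer types for all available devices are created once and cached in a
// function-local static; lookups simply match the requested device index.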
ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
    static std::vector<ggml_backend_buffer_type> bufts = []() {
        std::vector<ggml_backend_buffer_type> vec;
        auto devices = ggml_vk_available_devices_internal(0);
        vec.reserve(devices.size());

        for (const auto & dev : devices) {
            vec.push_back({
                /* .iface   = */ ggml_backend_kompute_buffer_type_interface,
                /* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
            });
        }
        return vec;
    }();

    auto it = std::find_if(bufts.begin(), bufts.end(), [device](const ggml_backend_buffer_type & t) {
        return device == static_cast<ggml_backend_kompute_buffer_type_context *>(t.context)->device;
    });
    return it != bufts.end() ? &*it : nullptr;
}

// backend

static const char * ggml_backend_kompute_name(ggml_backend_t backend) {
    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
    return ctx->name.c_str();
}

static void ggml_backend_kompute_free(ggml_backend_t backend) {
    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);

    assert(ctx == s_kompute_context);
    s_kompute_context = nullptr;
    if (ctx != nullptr) {
        delete ctx;
    }

    delete backend;
}

static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(ggml_backend_t backend) {
    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
    return ggml_backend_kompute_buffer_type(ctx->device);
}

static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
    ggml_vk_graph_compute(ctx, cgraph);
    return GGML_STATUS_SUCCESS;
}

static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
    GGML_UNUSED(backend);
    return ggml_vk_supports_op(op);
}

static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(backend);
    return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
}

static struct ggml_backend_i kompute_backend_i = {
    /* .get_name                = */ ggml_backend_kompute_name,
    /* .free                    = */ ggml_backend_kompute_free,
    /* .get_default_buffer_type = */ ggml_backend_kompute_get_default_buffer_type,
    /* .set_tensor_async        = */ NULL,
    /* .get_tensor_async        = */ NULL,
    /* .cpy_tensor_async        = */ NULL,
    /* .synchronize             = */ NULL,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_kompute_graph_compute,
    /* .supports_op             = */ ggml_backend_kompute_supports_op,
    /* .supports_buft           = */ ggml_backend_kompute_supports_buft,
    /* .offload_op              = */ NULL,
    /* .event_new               = */ NULL,
    /* .event_free              = */ NULL,
    /* .event_record            = */ NULL,
    /* .event_wait              = */ NULL,
    /* .event_synchronize       = */ NULL,
};

static ggml_guid_t ggml_backend_kompute_guid() {
    static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
    return &guid;
}

ggml_backend_t ggml_backend_kompute_init(int device) {
    GGML_ASSERT(s_kompute_context == nullptr);
    s_kompute_context = new ggml_kompute_context(device);

    ggml_backend_t kompute_backend = new ggml_backend {
        /* .guid      = */ ggml_backend_kompute_guid(),
        /* .interface = */ kompute_backend_i,
        /* .context   = */ s_kompute_context,
    };

    return kompute_backend;
}
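
// Typical usage (a minimal sketch; assumes the generic ggml-backend API and that
// device 0 exists):
//
//   ggml_backend_t backend = ggml_backend_kompute_init(0);
//   ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
//   // ... allocate tensors from buft, build a ggml_cgraph ...
//   ggml_backend_graph_compute(backend, graph);
//   ggml_backend_free(backend);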

bool ggml_backend_is_kompute(ggml_backend_t backend) {
    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
}

static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) {
    GGML_UNUSED(params);
    return ggml_backend_kompute_init(intptr_t(user_data));
}

extern "C" int ggml_backend_kompute_reg_devices();

int ggml_backend_kompute_reg_devices() {
    auto devices = ggml_vk_available_devices_internal(0);
    for (const auto & device : devices) {
        ggml_backend_register(
            ggml_kompute_format_name(device.index).c_str(),
            ggml_backend_reg_kompute_init,
            ggml_backend_kompute_buffer_type(device.index),
            reinterpret_cast<void *>(intptr_t(device.index))
        );
    }
    return devices.size();
}