// examples/export-lora/export-lora.cpp

#include "common.h"
#include "ggml.h"
#include "ggml-alloc.h"

#include <algorithm> // std::replace in argument parsing
#include <cerrno>    // errno
#include <cstdint>   // std::uint32_t
#include <cstdio>    // FILE, fopen, fprintf
#include <cstring>   // strerror
#include <vector>
#include <string>
#include <thread>

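// export-lora merges one or more LoRA adapters (GGLA format) into a GGUF base
// model and writes the merged model out as a new standalone GGUF file.
//
// Example invocation (file names are illustrative, not from the repository):
//
//   ./export-lora \
//       -m base-model-q8_0.gguf \
//       -o merged-model-q8_0.gguf \
//       -l lora-adapter.bin
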
struct lora_info {
    std::string filename;
    float scale;
};

struct export_lora_params {
    std::string fn_model_base;
    std::string fn_model_out;
    std::vector<struct lora_info> lora;
    int n_threads;
};

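// One adapter as loaded into memory. lora_r and lora_alpha are read from the
// GGLA header: r is the rank of the low-rank decomposition and alpha the
// scaling numerator chosen at training time; apply_lora() combines them with
// the user-supplied scale into the effective merge strength.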
struct lora_data {
    struct lora_info     info;
    std::vector<uint8_t> data;
    struct ggml_context * ctx;

    uint32_t lora_r;
    uint32_t lora_alpha;
};

struct llama_file {
    // use FILE * so we don't have to re-open the file to mmap
    FILE * fp;
    size_t size;

    llama_file(const char * fname, const char * mode) {
        fp = std::fopen(fname, mode);
        if (fp == NULL) {
            size = 0;
        } else {
            seek(0, SEEK_END);
            size = tell();
            seek(0, SEEK_SET);
        }
    }

    size_t tell() const {
#ifdef _WIN32
        __int64 ret = _ftelli64(fp);
#else
        long ret = std::ftell(fp);
#endif
        GGML_ASSERT(ret != -1); // this really shouldn't fail
        return (size_t) ret;
    }

    void seek(size_t offset, int whence) {
#ifdef _WIN32
        int ret = _fseeki64(fp, (__int64) offset, whence);
#else
        int ret = std::fseek(fp, (long) offset, whence);
#endif
        GGML_ASSERT(ret == 0); // same
    }

    void read_raw(void * ptr, size_t size) {
        if (size == 0) {
            return;
        }
        errno = 0;
        std::size_t ret = std::fread(ptr, size, 1, fp);
        if (ferror(fp)) {
            die_fmt("read error: %s", strerror(errno));
        }
        if (ret != 1) {
            die("unexpectedly reached end of file");
        }
    }

    std::uint32_t read_u32() {
        std::uint32_t ret;
        read_raw(&ret, sizeof(ret));
        return ret;
    }

    std::string read_string(std::uint32_t len) {
        std::vector<char> chars(len);
        read_raw(chars.data(), len);
        return std::string(chars.data(), len);
    }

    void write_raw(const void * ptr, size_t size) {
        if (size == 0) {
            return;
        }
        errno = 0;
        size_t ret = std::fwrite(ptr, size, 1, fp);
        if (ret != 1) {
            die_fmt("write error: %s", strerror(errno));
        }
    }

    void write_u32(std::uint32_t val) {
        write_raw(&val, sizeof(val));
    }

    bool eof() {
        return tell() >= size;
    }

    ~llama_file() {
        if (fp) {
            std::fclose(fp);
        }
    }
};

static struct export_lora_params get_default_export_lora_params() {
    struct export_lora_params result;
    result.fn_model_base = "";
    result.fn_model_out  = "";
    result.n_threads = GGML_DEFAULT_N_THREADS;
    return result;
}

static void export_lora_print_usage(int /*argc*/, char ** argv, const struct export_lora_params * params) {
    fprintf(stderr, "usage: %s [options]\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, "  -h, --help                         show this help message and exit\n");
    fprintf(stderr, "  -m FNAME, --model-base FNAME       model path from which to load base model (default '%s')\n", params->fn_model_base.c_str());
    fprintf(stderr, "  -o FNAME, --model-out FNAME        path to save exported model (default '%s')\n", params->fn_model_out.c_str());
    fprintf(stderr, "  -l FNAME, --lora FNAME             apply LoRA adapter\n");
    fprintf(stderr, "  -s FNAME S, --lora-scaled FNAME S  apply LoRA adapter with user-defined scaling S\n");
    fprintf(stderr, "  -t N, --threads N                  number of threads to use during computation (default: %d)\n", params->n_threads);
}

static bool export_lora_params_parse(int argc, char ** argv, struct export_lora_params * params) {
    bool invalid_param = false;
    std::string arg;
    struct export_lora_params default_params = get_default_export_lora_params();
    const std::string arg_prefix = "--";

    for (int i = 1; i < argc; i++) {
        arg = argv[i];
        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
            std::replace(arg.begin(), arg.end(), '_', '-');
        }

        if (arg == "-h" || arg == "--help") {
            // advertised in the usage text, so handle it explicitly
            export_lora_print_usage(argc, argv, &default_params);
            exit(0);
        } else if (arg == "-m" || arg == "--model-base") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_model_base = argv[i];
        } else if (arg == "-o" || arg == "--model-out") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->fn_model_out = argv[i];
        } else if (arg == "-l" || arg == "--lora") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            struct lora_info lora;
            lora.filename = argv[i];
            lora.scale = 1.0f;
            params->lora.push_back(lora);
        } else if (arg == "-s" || arg == "--lora-scaled") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            struct lora_info lora;
            lora.filename = argv[i];
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            lora.scale = std::stof(argv[i]);
            params->lora.push_back(lora);
        } else if (arg == "-t" || arg == "--threads") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params->n_threads = std::stoi(argv[i]);
            if (params->n_threads <= 0) {
                params->n_threads = std::thread::hardware_concurrency();
            }
        } else {
            fprintf(stderr, "error: unknown argument: '%s'\n", arg.c_str());
            export_lora_print_usage(argc, argv, &default_params);
            exit(1);
        }
    }

    if (params->fn_model_base == default_params.fn_model_base) {
        fprintf(stderr, "error: please specify a filename for model-base.\n");
        export_lora_print_usage(argc, argv, &default_params);
        exit(1);
    }
    if (params->fn_model_out == default_params.fn_model_out) {
        fprintf(stderr, "error: please specify a filename for model-out.\n");
        export_lora_print_usage(argc, argv, &default_params);
        exit(1);
    }
    if (invalid_param) {
        fprintf(stderr, "error: invalid parameter for argument: '%s'\n", arg.c_str());
        export_lora_print_usage(argc, argv, &default_params);
        exit(1);
    }
    return true;
}

static void free_lora(struct lora_data * lora) {
    if (lora->ctx != NULL) {
        ggml_free(lora->ctx);
    }
    delete lora;
}

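// Load a GGLA adapter. The on-disk layout, as parsed below, is:
//   u32 magic ('ggla'), u32 version (must be 1), u32 lora_r, u32 lora_alpha,
//   then per tensor: u32 n_dims, u32 name_len, u32 type, u32 ne[n_dims],
//   name bytes, padding to a 32-byte boundary, raw tensor data.
// Returns NULL (with a warning) if the adapter file cannot be opened.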
static struct lora_data * load_lora(struct lora_info * info) {
    struct lora_data * result = new struct lora_data;
    result->info = *info;
    result->ctx = NULL;
    result->lora_r     = 1;
    result->lora_alpha = 1;

    struct llama_file file(info->filename.c_str(), "rb");
    if (file.fp == NULL) {
        fprintf(stderr, "warning: Could not open lora adapter '%s'. Ignoring this adapter.\n",
            info->filename.c_str());
        free_lora(result);
        return NULL;
    }

    struct ggml_init_params params_ggml;
    params_ggml.mem_size   = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE;
    params_ggml.mem_buffer = NULL;
    params_ggml.no_alloc   = true;
    result->ctx = ggml_init(params_ggml);

    uint32_t magic   = file.read_u32();
    if (magic != LLAMA_FILE_MAGIC_GGLA) {
        die_fmt("unexpected lora header file magic in '%s'", info->filename.c_str());
    }
    uint32_t version = file.read_u32();
    if (version != 1) {
        die_fmt("unexpected lora file version '%u' in '%s'", (unsigned) version, info->filename.c_str());
    }
    result->lora_r     = file.read_u32();
    result->lora_alpha = file.read_u32();
    // read tensor infos from file
    std::vector<char> name_buf;
    std::vector<struct ggml_tensor *> tensors;
    std::vector<size_t> tensors_offset;
    size_t total_nbytes_pad = 0;
    while (!file.eof()) {
        int64_t ne[4]    = {1,1,1,1};
        uint32_t n_dims  = file.read_u32();
        uint32_t namelen = file.read_u32();
        uint32_t type    = file.read_u32();
        for (uint32_t k = 0; k < n_dims; ++k) {
            ne[k] = (int64_t)file.read_u32();
        }
        name_buf.clear();
        name_buf.resize(namelen + 1, '\0');
        file.read_raw(name_buf.data(), namelen);
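        // tensor data is aligned to a 32-byte boundary; (0 - pos) & 31 is
        // exactly the number of padding bytes to skip to reach it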
        file.seek((0-file.tell()) & 31, SEEK_CUR);
        size_t offset = file.tell();
        struct ggml_tensor * tensor = ggml_new_tensor(result->ctx, (enum ggml_type) type, n_dims, ne);
        ggml_set_name(tensor, name_buf.data());
        size_t nbytes     = ggml_nbytes(tensor);
        size_t nbytes_pad = ggml_nbytes_pad(tensor);
        total_nbytes_pad += nbytes_pad;
        tensors.push_back(tensor);
        tensors_offset.push_back(offset);
        file.seek(nbytes, SEEK_CUR);
    }
    // read tensor data
    result->data.resize(total_nbytes_pad);
    size_t data_offset = 0;
    for (size_t i = 0; i < tensors.size(); ++i) {
        struct ggml_tensor * tensor = tensors[i];
        size_t offset     = tensors_offset[i];
        size_t nbytes     = ggml_nbytes(tensor);
        size_t nbytes_pad = ggml_nbytes_pad(tensor);
        file.seek(offset, SEEK_SET);
        tensor->data = result->data.data() + data_offset;
        file.read_raw(tensor->data, nbytes);
        data_offset += nbytes_pad;
    }
    return result;
}

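// Build a one-shot graph that merges the adapter into the base tensor in
// place: tensor += scaling * mul_mat(lora_a, lora_b). This is the standard
// LoRA merge W' = W + s*(B*A), modulo ggml's mul_mat layout convention.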
static struct ggml_cgraph * build_graph_lora(
    struct ggml_context * ctx,
    struct ggml_tensor * tensor,
    struct ggml_tensor * lora_a,
    struct ggml_tensor * lora_b,
    float scaling
) {
    struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b);
    if (scaling != 1.0f) {
        ab = ggml_scale(ctx, ab, scaling);
    }
    struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, res);
    return gf;
}

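// Apply one adapter to one base tensor. The adapter stores its factors under
// "<name>.loraA" and "<name>.loraB"; tensors without both factors are left
// unchanged. The effective strength is scale * alpha / r, so with the
// illustrative values alpha = 32, r = 16, scale = 1.0 the update lands at 2.0.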
static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int n_threads) {
    if (lora->ctx == NULL) {
        return false;
    }
    std::string name = ggml_get_name(tensor);
    std::string name_a = name + std::string(".loraA");
    std::string name_b = name + std::string(".loraB");
    struct ggml_tensor * lora_a = ggml_get_tensor(lora->ctx, name_a.c_str());
    struct ggml_tensor * lora_b = ggml_get_tensor(lora->ctx, name_b.c_str());
    if (lora_a == NULL || lora_b == NULL) {
        return false;
    }

    float scaling = lora->info.scale * (float)lora->lora_alpha / (float)lora->lora_r;

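    // no_alloc context: it only needs to hold the graph bookkeeping and a few
    // tensor headers; the graph allocator below provides the compute buffers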
    struct ggml_init_params params;
    params.mem_size   = GGML_OBJECT_SIZE + ggml_graph_overhead() + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
    params.mem_buffer = NULL;
    params.no_alloc   = true;
    struct ggml_context * ctx = NULL;
    struct ggml_gallocr * alloc = NULL;
    struct ggml_cgraph  * gf = NULL;

    ctx   = ggml_init(params);
    alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
    gf    = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);

    ggml_gallocr_alloc_graph(alloc, gf);

    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
    static std::vector<uint8_t> data_work;
    data_work.resize(cplan.work_size);
    cplan.work_data = data_work.data();

    ggml_graph_compute(gf, &cplan);

    ggml_gallocr_free(alloc);
    ggml_free(ctx);
    return true;
}

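// Merge pipeline: load every adapter, copy the base model's gguf metadata
// (kv pairs and tensor infos) into the output, then stream the tensors one
// by one: read, apply all adapters, write back with gguf alignment padding.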
static void export_lora(struct export_lora_params * params) {
    // load all loras
    std::vector<struct lora_data *> loras;
    for (size_t i = 0; i < params->lora.size(); ++i) {
        struct lora_data * lora = load_lora(&params->lora[i]);
        if (lora != NULL) {
            loras.push_back(lora);
        }
    }
    if (loras.size() == 0) {
        fprintf(stderr, "warning: no lora adapters will be applied.\n");
    }

    // open input file
    struct llama_file fin(params->fn_model_base.c_str(), "rb");
    if (!fin.fp) {
        die_fmt("Could not open file '%s'", params->fn_model_base.c_str());
    }

    // open base model gguf, read tensors without their data
    struct ggml_context * ctx_in;
    struct gguf_init_params params_gguf;
    params_gguf.no_alloc = true;
    params_gguf.ctx      = &ctx_in;
    struct gguf_context * gguf_in = gguf_init_from_file(params->fn_model_base.c_str(), params_gguf);
    if (gguf_in == NULL) {
        die_fmt("Could not parse gguf file '%s'", params->fn_model_base.c_str());
    }

    // create new gguf
    struct gguf_context * gguf_out = gguf_init_empty();

    // copy meta data from base model: kv and tensors
    gguf_set_kv(gguf_out, gguf_in);
    int n_tensors = gguf_get_n_tensors(gguf_in);
    for (int i = 0; i < n_tensors; ++i) {
        const char * name = gguf_get_tensor_name(gguf_in, i);
        struct ggml_tensor * tensor = ggml_get_tensor(ctx_in, name);
        gguf_add_tensor(gguf_out, tensor);
    }

    // create output file
    struct llama_file fout(params->fn_model_out.c_str(), "wb");
    if (!fout.fp) {
        die_fmt("Could not create file '%s'", params->fn_model_out.c_str());
    }

    // write gguf meta data
    std::vector<uint8_t> meta;
    meta.resize(gguf_get_meta_size(gguf_out));
    gguf_get_meta_data(gguf_out, meta.data());
    fout.write_raw(meta.data(), meta.size());

    std::vector<uint8_t> data;
    std::vector<uint8_t> padding;
    for (int i = 0; i < n_tensors; ++i) {
        const char * name = gguf_get_tensor_name(gguf_in, i);
        struct ggml_tensor * tensor = ggml_get_tensor(ctx_in, name);

        // read tensor data
        data.resize(ggml_nbytes(tensor));
        tensor->data = data.data();
        size_t offset = gguf_get_tensor_offset(gguf_in, i);
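        // gguf tensor offsets are relative to the start of the data section,
        // which begins immediately after the serialized metadata block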
        fin.seek(offset + meta.size(), SEEK_SET);
        fin.read_raw(data.data(), data.size());

        // apply all loras
        for (size_t k = 0; k < loras.size(); ++k) {
            apply_lora(tensor, loras[k], params->n_threads);
        }

        // write tensor data + padding
        padding.clear();
        padding.resize(GGML_PAD(data.size(), gguf_get_alignment(gguf_out)) - data.size(), 0);

        GGML_ASSERT(fout.tell() == offset + meta.size());
        // fout.seek(offset + meta.size(), SEEK_SET);
        fout.write_raw(data.data(), data.size());
        fout.write_raw(padding.data(), padding.size());

        if (i % 2 == 0) {
            printf(".");
        }
    }
    printf("\n");

    // close gguf
    gguf_free(gguf_out);
    gguf_free(gguf_in);

    // free loras
    for (size_t i = 0; i < loras.size(); ++i) {
        free_lora(loras[i]);
    }
}

int main(int argc, char ** argv) {
    struct export_lora_params params = get_default_export_lora_params();

    if (!export_lora_params_parse(argc, argv, &params)) {
        return 1;
    }

    export_lora(&params);

    return 0;
}