  1  {
  2    "nbformat": 4,
  3    "nbformat_minor": 0,
  4    "metadata": {
  5      "accelerator": "GPU",
  6      "colab": {
  7        "name": "21 - Export and run other machine learning models",
  8        "provenance": [],
  9        "collapsed_sections": []
 10      },
 11      "kernelspec": {
 12        "display_name": "Python 3",
 13        "name": "python3"
 14      }
 15    },
 16    "cells": [
 17      {
 18        "cell_type": "markdown",
 19        "metadata": {
 20          "id": "4Pjmz-RORV8E"
 21        },
 22        "source": [
 23          "# Export and run other machine learning models\n",
 24          "\n",
 25          "txtai primarily has support for [Hugging Face Transformers](https://github.com/huggingface/transformers) and [ONNX](https://github.com/microsoft/onnxruntime) models. This enables txtai to hook into the rich model framework available in Python, export this functionality via the API to other languages (JavaScript, Java, Go, Rust) and even export and natively load models with ONNX.\n",
 26          "\n",
 27          "What about other machine learning frameworks? Say we have an existing TF-IDF + Logistic Regression model that has been well tuned. Can this model be exported to ONNX and used in txtai for labeling and similarity queries? Or what about a simple PyTorch text classifier? Yes, both of these can be done!\n",
 28          "\n",
 29          "With the [onnxmltools](https://github.com/onnx/onnxmltools) library, traditional models from [scikit-learn](https://scikit-learn.org/stable/), [XGBoost](https://xgboost.readthedocs.io/en/latest/) and others can be exported to ONNX and loaded with txtai. Additionally, Hugging Face's trainer module can train generic PyTorch modules. This notebook will walk through all these examples.\n",
 30          "\n"
 31        ]
 32      },
 33      {
 34        "cell_type": "markdown",
 35        "metadata": {
 36          "id": "Dk31rbYjSTYm"
 37        },
 38        "source": [
 39          "# Install dependencies\n",
 40          "\n",
 41          "Install `txtai` and all dependencies."
 42        ]
 43      },
 44      {
 45        "cell_type": "code",
 46        "metadata": {
 47          "id": "XMQuuun2R06J"
 48        },
 49        "source": [
 50          "%%capture\n",
 51          "!pip install git+https://github.com/neuml/txtai#egg=txtai[pipeline,similarity] datasets"
 52        ],
 53        "execution_count": null,
 54        "outputs": []
 55      },
 56      {
 57        "cell_type": "markdown",
 58        "metadata": {
 59          "id": "r6nmtieHdMfr"
 60        },
 61        "source": [
 62          "# Train a TF-IDF + Logistic Regression model\n",
 63          "\n",
 64          "For this example, we'll load the emotion dataset from Hugging Face datasets and build a TF-IDF + Logistic Regression model with scikit-learn.\n",
 65          "\n",
 66          "The emotion dataset has the following labels:\n",
 67          "\n",
 68          "- sadness (0)\n",
 69          "- joy (1)\n",
 70          "- love (2)\n",
 71          "- anger (3)\n",
 72          "- fear (4)\n",
 73          "- surprise (5)\n"
 74        ]
 75      },
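The label ids above are reused later in the notebook when attaching id2label/label2id mappings to the exported models. A minimal sketch of that mapping (names copied from the list above):

```python
# Label mapping for the emotion dataset, as listed above
id2label = {0: "sadness", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"}

# Inverse mapping, built the same way the later cells do
label2id = {v: k for k, v in id2label.items()}

print(label2id["joy"])  # → 1
print(id2label[3])      # → anger
```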
 76      {
 77        "cell_type": "code",
 78        "metadata": {
 79          "id": "pg9-tUxEdRfk"
 80        },
 81        "source": [
 82          "from datasets import load_dataset\n",
 83          "\n",
 84          "from sklearn.feature_extraction.text import TfidfVectorizer\n",
 85          "from sklearn.linear_model import LogisticRegression\n",
 86          "from sklearn.pipeline import Pipeline\n",
 87          "\n",
 88          "ds = load_dataset(\"emotion\")\n",
 89          "\n",
 90          "# Train the model\n",
 91          "pipeline = Pipeline([\n",
 92          "    ('tfidf', TfidfVectorizer()),\n",
 93          "    ('lr', LogisticRegression(max_iter=250))\n",
 94          "])\n",
 95          "\n",
 96          "pipeline.fit(ds[\"train\"][\"text\"], ds[\"train\"][\"label\"])\n",
 97          "\n",
 98          "# Determine accuracy on validation set\n",
 99          "results = pipeline.predict(ds[\"validation\"][\"text\"])\n",
100          "labels = ds[\"validation\"][\"label\"]\n",
101          "\n",
102          "results = [results[x] == label for x, label in enumerate(labels)]\n",
103          "print(\"Accuracy =\", sum(results) / len(ds[\"validation\"]))"
104        ],
105        "execution_count": null,
106        "outputs": [
107          {
108            "output_type": "stream",
109            "name": "stderr",
110            "text": [
111              "Using custom data configuration default\n",
112              "Reusing dataset emotion (/root/.cache/huggingface/datasets/emotion/default/0.0.0/348f63ca8e27b3713b6c04d723efe6d824a56fb3d1449794716c0f0296072705)\n"
113            ]
114          },
115          {
116            "output_type": "stream",
117            "name": "stdout",
118            "text": [
119              "Accuracy = 0.8595\n"
120            ]
121          }
122        ]
123      },
124      {
125        "cell_type": "markdown",
126        "metadata": {
127          "id": "49jZD4jQgdBg"
128        },
129        "source": [
130          "86% accuracy - not too bad! While we all get caught up in deep learning and advanced methods, good old TF-IDF + Logistic Regression is still a solid performer and runs much faster. If that level of accuracy works, there's no reason to overcomplicate things."
131        ]
132      },
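To see why this approach is so fast at inference time, here is a toy, stdlib-only sketch of TF-IDF weighting. It is an illustration only; scikit-learn's TfidfVectorizer adds idf smoothing and L2 normalization on top of this basic scheme:

```python
import math
from collections import Counter

def tfidf(docs):
    # Term frequency per document
    tfs = [Counter(doc.split()) for doc in docs]

    # Document frequency per term
    df = Counter(term for tf in tfs for term in set(tf))

    # Weight = tf * idf, with a simple unsmoothed idf
    n = len(docs)
    return [{term: count * math.log(n / df[term]) for term, count in tf.items()} for tf in tfs]

weights = tfidf(["i feel happy", "i feel sad", "i feel joy"])

# Terms appearing in every document ("i", "feel") get zero weight,
# while the distinguishing term ("happy") is weighted up
print(weights[0])
```

Inference then reduces to a sparse dot product per class, which is why this model benchmarks so well later in the notebook.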
133      {
134        "cell_type": "markdown",
135        "metadata": {
136          "id": "zZtHxNSwFNGC"
137        },
138        "source": [
139          "# Export and load with txtai\n",
140          "\n",
141          "The next section exports this model to ONNX and shows how the model can be used for similarity queries. "
142        ]
143      },
144      {
145        "cell_type": "code",
146        "metadata": {
147          "colab": {
148            "base_uri": "https://localhost:8080/"
149          },
150          "id": "JBeScS5dFNeW",
151          "outputId": "e1b5cbf4-87dd-4598-e7ee-14e36cf31a7c"
152        },
153        "source": [
154          "from txtai.pipeline import Labels, MLOnnx, Similarity\n",
155          "\n",
156          "def tokenize(inputs, **kwargs):\n",
157          "    if isinstance(inputs, str):\n",
158          "        inputs = [inputs]\n",
159          "\n",
160          "    return {\"input_ids\": [[x] for x in inputs]}\n",
161          "\n",
162          "def query(model, tokenizer, multilabel=False):\n",
163          "    # Load models into similarity pipeline\n",
164          "    similarity = Similarity((model, tokenizer), dynamic=False)\n",
165          "\n",
166          "    # Add labels to model\n",
167          "    similarity.pipeline.model.config.id2label = {0: \"sadness\", 1: \"joy\", 2: \"love\", 3: \"anger\", 4: \"fear\", 5: \"surprise\"}\n",
168          "    similarity.pipeline.model.config.label2id = dict((v, k) for k, v in similarity.pipeline.model.config.id2label.items())\n",
169          "\n",
170          "    inputs = [\"that caught me off guard\", \"I didn t see that coming\", \"i feel bad\", \"What a wonderful goal!\"]\n",
171          "    scores = similarity(\"joy\", inputs, multilabel)\n",
172          "    for uid, score in scores[:5]:\n",
173          "        print(inputs[uid], score)\n",
174          "\n",
175          "# Export to ONNX\n",
176          "onnx = MLOnnx()\n",
177          "model = onnx(pipeline)\n",
178          "\n",
179          "# Create labels pipeline using scikit-learn ONNX model\n",
180          "sklabels = Labels((model, tokenize), dynamic=False)\n",
181          "\n",
182          "# Add labels to model\n",
183          "sklabels.pipeline.model.config.id2label = {0: \"sadness\", 1: \"joy\", 2: \"love\", 3: \"anger\", 4: \"fear\", 5: \"surprise\"}\n",
184          "sklabels.pipeline.model.config.label2id = dict((v, k) for k, v in sklabels.pipeline.model.config.id2label.items())\n",
185          "\n",
186          "# Run test query using model\n",
187          "query(model, tokenize, None)"
188        ],
189        "execution_count": null,
190        "outputs": [
191          {
192            "output_type": "stream",
193            "name": "stdout",
194            "text": [
195              "What a wonderful goal! 0.909473717212677\n",
196              "I didn t see that coming 0.47113093733787537\n",
197              "that caught me off guard 0.42067453265190125\n",
198              "i feel bad 0.019547615200281143\n"
199            ]
200          }
201        ]
202      },
203      {
204        "cell_type": "markdown",
205        "metadata": {
206          "id": "d-y8gFJwCwKN"
207        },
208        "source": [
209          "txtai can use a standard text classification model for similarity queries, treating the model's labels as a fixed set of queries. The output above shows the best results for the query \"joy\"."
210        ]
211      },
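One subtle detail in the cell above is the custom `tokenize` function: because TF-IDF vectorization happens inside the exported ONNX model, the "input_ids" it returns are just raw strings wrapped in lists rather than token ids. A quick check of that contract:

```python
# Mirrors the tokenize function from the cell above: pass text through
# untouched, since the exported scikit-learn model vectorizes internally
def tokenize(inputs, **kwargs):
    if isinstance(inputs, str):
        inputs = [inputs]

    return {"input_ids": [[x] for x in inputs]}

print(tokenize("i feel bad"))  # → {'input_ids': [['i feel bad']]}
print(tokenize(["a", "b"]))    # → {'input_ids': [['a'], ['b']]}
```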
212      {
213        "cell_type": "markdown",
214        "metadata": {
215          "id": "cbqwX7GgKBkf"
216        },
217        "source": [
218          "# Train a PyTorch model\n",
219          "\n",
220          "The next section defines a simple PyTorch text classifier. The Transformers library has a trainer package that supports training PyTorch models, provided they follow a few standard conventions and naming rules."
221        ]
222      },
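One of those conventions is that `forward` should return a loss when labels are passed, which the cell below implements with `CrossEntropyLoss`. For intuition, cross entropy over raw logits is log-sum-exp minus the target logit; a stdlib-only sketch:

```python
import math

def cross_entropy(logits, label):
    # -log(softmax(logits)[label]) = logsumexp(logits) - logits[label]
    m = max(logits)  # subtract the max for numerical stability
    logsumexp = m + math.log(sum(math.exp(z - m) for z in logits))
    return logsumexp - logits[label]

# A confident, correct prediction yields a small loss
print(cross_entropy([4.0, 0.1, -2.0], 0))

# The same logits with a wrong label yield a large loss
print(cross_entropy([4.0, 0.1, -2.0], 2))
```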
223      {
224        "cell_type": "code",
225        "metadata": {
226          "colab": {
227            "base_uri": "https://localhost:8080/",
228            "height": 239
229          },
230          "id": "k8PkTlBLKBTy",
231          "outputId": "4f48bfb2-2f16-45e3-d3e6-f1a2a747fd09"
232        },
233        "source": [
234          "# Set predictable seeds\n",
235          "import os\n",
236          "import random\n",
237          "import torch\n",
238          "\n",
239          "import numpy as np\n",
240          "\n",
241          "from torch import nn\n",
242          "from torch.nn import CrossEntropyLoss\n",
243          "from transformers import AutoConfig, AutoTokenizer\n",
244          "\n",
245          "from txtai.models import Registry\n",
246          "from txtai.pipeline import HFTrainer\n",
247          "\n",
248          "from transformers.modeling_outputs import SequenceClassifierOutput\n",
249          "\n",
250          "def seed(seed=42):\n",
251          "    random.seed(seed)\n",
252          "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
253          "    np.random.seed(seed)\n",
254          "    torch.manual_seed(seed)\n",
255          "    torch.cuda.manual_seed(seed)\n",
256          "    torch.backends.cudnn.deterministic = True\n",
257          "\n",
258          "class Simple(nn.Module):\n",
259          "    def __init__(self, vocab, dimensions, labels):\n",
260          "        super().__init__()\n",
261          "\n",
262          "        self.config = AutoConfig.from_pretrained(\"bert-base-uncased\")\n",
263          "        self.labels = labels\n",
264          "\n",
265          "        self.embedding = nn.EmbeddingBag(vocab, dimensions)\n",
266          "        self.classifier = nn.Linear(dimensions, labels)\n",
267          "        self.init_weights()\n",
268          "\n",
269          "    def init_weights(self):\n",
270          "        initrange = 0.5\n",
271          "        self.embedding.weight.data.uniform_(-initrange, initrange)\n",
272          "        self.classifier.weight.data.uniform_(-initrange, initrange)\n",
273          "        self.classifier.bias.data.zero_()\n",
274          "\n",
275          "    def forward(self, input_ids=None, labels=None, **kwargs):\n",
276          "        embeddings = self.embedding(input_ids)\n",
277          "        logits = self.classifier(embeddings)\n",
278          "\n",
279          "        loss = None\n",
280          "        if labels is not None:\n",
281          "            loss_fct = CrossEntropyLoss()\n",
282          "            loss = loss_fct(logits.view(-1, self.labels), labels.view(-1))\n",
283          "\n",
284          "        return SequenceClassifierOutput(\n",
285          "            loss=loss,\n",
286          "            logits=logits,\n",
287          "        )\n",
288          "\n",
289          "# Set seed for reproducibility\n",
290          "seed()\n",
291          "\n",
292          "# Define model\n",
293          "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
294          "model = Simple(tokenizer.vocab_size, 128, len(ds[\"train\"].unique(\"label\")))\n",
295          "\n",
296          "# Train model\n",
297          "train = HFTrainer()\n",
298          "model, tokenizer = train((model, tokenizer), ds[\"train\"], per_device_train_batch_size=8, learning_rate=1e-3, num_train_epochs=15, logging_steps=10000)\n",
299          "\n",
300          "# Register custom model to fully support pipelines\n",
301          "Registry.register(model)\n",
302          "\n",
303          "# Create labels pipeline using PyTorch model\n",
304          "thlabels = Labels((model, tokenizer), dynamic=False)\n",
305          "\n",
306          "# Determine accuracy on validation set\n",
307          "results = [row[\"label\"] == thlabels(row[\"text\"])[0][0] for row in ds[\"validation\"]]\n",
308          "print(\"Accuracy = \", sum(results) / len(ds[\"validation\"]))"
309        ],
310        "execution_count": null,
311        "outputs": [
312          {
313            "output_type": "stream",
314            "name": "stderr",
315            "text": [
316              "Loading cached processed dataset at /root/.cache/huggingface/datasets/emotion/default/0.0.0/348f63ca8e27b3713b6c04d723efe6d824a56fb3d1449794716c0f0296072705/cache-a983327c4471f5aa.arrow\n"
317            ]
318          },
319          {
320            "output_type": "display_data",
321            "data": {
322              "text/html": [
323                "\n",
324                "    <div>\n",
325                "      \n",
326                "      <progress value='30000' max='30000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
327                "      [30000/30000 02:28, Epoch 15/15]\n",
328                "    </div>\n",
329                "    <table border=\"1\" class=\"dataframe\">\n",
330                "  <thead>\n",
331                "    <tr style=\"text-align: left;\">\n",
332                "      <th>Step</th>\n",
333                "      <th>Training Loss</th>\n",
334                "    </tr>\n",
335                "  </thead>\n",
336                "  <tbody>\n",
337                "    <tr>\n",
338                "      <td>10000</td>\n",
339                "      <td>1.017600</td>\n",
340                "    </tr>\n",
341                "    <tr>\n",
342                "      <td>20000</td>\n",
343                "      <td>0.286200</td>\n",
344                "    </tr>\n",
345                "    <tr>\n",
346                "      <td>30000</td>\n",
347                "      <td>0.152500</td>\n",
348                "    </tr>\n",
349                "  </tbody>\n",
350                "</table><p>"
351              ],
352              "text/plain": [
353                "<IPython.core.display.HTML object>"
354              ]
355            },
356            "metadata": {}
357          },
358          {
359            "output_type": "stream",
360            "name": "stdout",
361            "text": [
362              "Accuracy =  0.883\n"
363            ]
364          }
365        ]
366      },
367      {
368        "cell_type": "markdown",
369        "metadata": {
370          "id": "nHQoJnrj60Pz"
371        },
372        "source": [
373          "88% accuracy this time. Pretty good for such a simple network, and something that could certainly be improved upon.\n",
374          "\n",
375          "Once again, let's run similarity queries using this model."
376        ]
377      },
378      {
379        "cell_type": "code",
380        "metadata": {
381          "colab": {
382            "base_uri": "https://localhost:8080/"
383          },
384          "id": "W5_NDInF5lFN",
385          "outputId": "38a2c126-63e9-40dc-f309-a29826b5b937"
386        },
387        "source": [
388          "query(model, tokenizer)"
389        ],
390        "execution_count": null,
391        "outputs": [
392          {
393            "output_type": "stream",
394            "name": "stdout",
395            "text": [
396              "What a wonderful goal! 1.0\n",
397              "that caught me off guard 0.9998751878738403\n",
398              "I didn t see that coming 0.7328283190727234\n",
399              "i feel bad 5.2972134609891875e-19\n"
400            ]
401          }
402        ]
403      },
404      {
405        "cell_type": "markdown",
406        "metadata": {
407          "id": "KmcsdIltDTwj"
408        },
409        "source": [
410          "Same result order as with the scikit-learn model, with some variation in the scores, which is expected given that this is a completely different model."
411        ]
412      },
413      {
414        "cell_type": "markdown",
415        "metadata": {
416          "id": "-fNTi2jb68rv"
417        },
418        "source": [
419          "# Pooled embeddings\n",
420          "\n",
421          "The PyTorch model above consists of an embeddings layer with a linear classifier on top of it. What if we take that embeddings layer and use it for similarity queries? Let's give it a try."
422        ]
423      },
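`nn.EmbeddingBag` defaults to mean pooling, so each input becomes the average of its token vectors and similarity reduces to cosine similarity between those averages. A stdlib-only sketch with made-up 2d vectors (the values are hypothetical, purely for illustration):

```python
import math

# Toy token embeddings (hypothetical values, for illustration only)
vectors = {
    "happy": [0.9, 0.1],
    "glad": [0.8, 0.2],
    "angry": [-0.7, 0.6],
}

def pool(tokens):
    # Mean-pool token vectors, as nn.EmbeddingBag does in its default "mean" mode
    return [sum(dim) / len(tokens) for dim in zip(*(vectors[t] for t in tokens))]

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))

# Related tokens pool to nearby vectors, unrelated ones point away
print(cosine(pool(["happy"]), pool(["glad"])))
print(cosine(pool(["happy", "glad"]), pool(["angry"])))
```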
424      {
425        "cell_type": "code",
426        "metadata": {
427          "colab": {
428            "base_uri": "https://localhost:8080/"
429          },
430          "id": "J1yhfHKC7N7L",
431          "outputId": "11567948-769a-44df-9057-9fe9837a73dd"
432        },
433        "source": [
434          "from txtai.embeddings import Embeddings\n",
435          "\n",
436          "class SimpleEmbeddings(nn.Module):\n",
437          "    def __init__(self, embeddings):\n",
438          "        super().__init__()\n",
439          "\n",
440          "        self.embeddings = embeddings\n",
441          "\n",
442          "    def forward(self, input_ids=None, **kwargs):\n",
443          "        return (self.embeddings(input_ids),)\n",
444          "\n",
445          "embeddings = Embeddings({\"method\": \"pooling\", \"path\": SimpleEmbeddings(model.embedding), \"tokenizer\": \"bert-base-uncased\"})\n",
446          "print(embeddings.similarity(\"mad\", [\"Glad you found it\", \"Happy to see you\", \"I'm angry\"]))"
447        ],
448        "execution_count": null,
449        "outputs": [
450          {
451            "output_type": "stream",
452            "name": "stdout",
453            "text": [
454              "[(2, 0.8323876857757568), (1, -0.11010512709617615), (0, -0.16152513027191162)]\n"
455            ]
456          }
457        ]
458      },
459      {
460        "cell_type": "markdown",
461        "metadata": {
462          "id": "0kTUEIcmBNuV"
463        },
464        "source": [
465          "It definitely looks like the embeddings have stored knowledge. Could these embeddings be good enough to build a semantic search index, especially for sentiment-based data, given the training dataset? Possibly. It would certainly run faster than a standard transformer model (see below)."
466        ]
467      },
468      {
469        "cell_type": "markdown",
470        "metadata": {
471          "id": "V7nAl3WtkBNK"
472        },
473        "source": [
474          "# Train a transformer model and compare accuracy/speed\n",
475          "\n",
476          "Let's train a standard transformer sequence classifier and compare accuracy and speed against the simpler models above."
477        ]
478      },
479      {
480        "cell_type": "code",
481        "metadata": {
482          "colab": {
483            "base_uri": "https://localhost:8080/",
484            "height": 274
485          },
486          "id": "46fMiJrAIBu4",
487          "outputId": "f0512cf8-3bc2-41ed-caff-e1541403f2a5"
488        },
489        "source": [
490          "train = HFTrainer()\n",
491          "model, tokenizer = train(\"microsoft/xtremedistil-l6-h384-uncased\", ds[\"train\"], logging_steps=2000)\n",
492          "\n",
493          "tflabels = Labels((model, tokenizer), dynamic=False)\n",
494          "\n",
495          "# Determine accuracy on validation set\n",
496          "results = [row[\"label\"] == tflabels(row[\"text\"])[0][0] for row in ds[\"validation\"]]\n",
497          "print(\"Accuracy = \", sum(results) / len(ds[\"validation\"]))"
498        ],
499        "execution_count": null,
500        "outputs": [
501          {
502            "output_type": "stream",
503            "name": "stderr",
504            "text": [
505              "Loading cached processed dataset at /root/.cache/huggingface/datasets/emotion/default/0.0.0/348f63ca8e27b3713b6c04d723efe6d824a56fb3d1449794716c0f0296072705/cache-98b7ef31bf6ca944.arrow\n",
506              "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/xtremedistil-l6-h384-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
507              "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
508            ]
509          },
510          {
511            "output_type": "display_data",
512            "data": {
513              "text/html": [
514                "\n",
515                "    <div>\n",
516                "      \n",
517                "      <progress value='6000' max='6000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
518                "      [6000/6000 07:13, Epoch 3/3]\n",
519                "    </div>\n",
520                "    <table border=\"1\" class=\"dataframe\">\n",
521                "  <thead>\n",
522                "    <tr style=\"text-align: left;\">\n",
523                "      <th>Step</th>\n",
524                "      <th>Training Loss</th>\n",
525                "    </tr>\n",
526                "  </thead>\n",
527                "  <tbody>\n",
528                "    <tr>\n",
529                "      <td>2000</td>\n",
530                "      <td>0.635500</td>\n",
531                "    </tr>\n",
532                "    <tr>\n",
533                "      <td>4000</td>\n",
534                "      <td>0.281700</td>\n",
535                "    </tr>\n",
536                "    <tr>\n",
537                "      <td>6000</td>\n",
538                "      <td>0.192600</td>\n",
539                "    </tr>\n",
540                "  </tbody>\n",
541                "</table><p>"
542              ],
543              "text/plain": [
544                "<IPython.core.display.HTML object>"
545              ]
546            },
547            "metadata": {}
548          },
549          {
550            "output_type": "stream",
551            "name": "stdout",
552            "text": [
553              "Accuracy =  0.926\n"
554            ]
555          }
556        ]
557      },
558      {
559        "cell_type": "markdown",
560        "metadata": {
561          "id": "ycvozGzPmlbS"
562        },
563        "source": [
564          "As expected, the accuracy is better. The model above is a distilled model; even higher accuracy can be obtained with a model like \"roberta-base\", with the tradeoff being increased training and inference time.\n",
565          "\n",
566          "Speaking of speed, let's compare the speed of these models."
567        ]
568      },
569      {
570        "cell_type": "code",
571        "metadata": {
572          "colab": {
573            "base_uri": "https://localhost:8080/"
574          },
575          "id": "nWQMRQm0NwdN",
576          "outputId": "4a49406c-b4eb-46b1-edab-de01c15fdccb"
577        },
578        "source": [
579          "import time\n",
580          "\n",
581          "# Test inputs\n",
582          "inputs = ds[\"test\"][\"text\"]\n",
583          "print(\"Testing speed of %d items\" % len(inputs))\n",
584          "\n",
585          "start = time.time()\n",
586          "r1 = sklabels(inputs, multilabel=None)\n",
587          "print(\"TF-IDF + Logistic Regression time =\", time.time() - start)\n",
588          "\n",
589          "start = time.time()\n",
590          "r2 = thlabels(inputs)\n",
591          "print(\"PyTorch time =\", time.time() - start)\n",
592          "\n",
593          "start = time.time()\n",
594          "r3 = tflabels(inputs)\n",
595          "print(\"Transformers time =\", time.time() - start, \"\\n\")\n",
596          "\n",
597          "# Compare model results\n",
598          "for x in range(5):\n",
599          "  print(\"index: %d\" % x)\n",
600          "  print(r1[x][0])\n",
601          "  print(r2[x][0])\n",
602          "  print(r3[x][0], \"\\n\")"
603        ],
604        "execution_count": null,
605        "outputs": [
606          {
607            "output_type": "stream",
608            "name": "stdout",
609            "text": [
610              "Testing speed of 2000 items\n",
611              "TF-IDF + Logistic Regression time = 1.0483319759368896\n"
612            ]
613          },
614          {
615            "output_type": "stream",
616            "name": "stdout",
617            "text": [
618              "PyTorch time = 2.0001697540283203\n",
619              "Transformers time = 13.71584439277649 \n",
620              "\n",
621              "index: 0\n",
622              "(0, 0.7258279323577881)\n",
623              "(0, 1.0)\n",
624              "(0, 0.998375654220581) \n",
625              "\n",
626              "index: 1\n",
627              "(0, 0.854256272315979)\n",
628              "(0, 1.0)\n",
629              "(0, 0.9983494281768799) \n",
630              "\n",
631              "index: 2\n",
632              "(0, 0.6306578516960144)\n",
633              "(0, 0.9999700784683228)\n",
634              "(0, 0.9982945322990417) \n",
635              "\n",
636              "index: 3\n",
637              "(1, 0.554378092288971)\n",
638              "(1, 0.9998960494995117)\n",
639              "(1, 0.99846351146698) \n",
640              "\n",
641              "index: 4\n",
642              "(0, 0.8961835503578186)\n",
643              "(0, 1.0)\n",
644              "(0, 0.9984095692634583) \n",
645              "\n"
646            ]
647          }
648        ]
649      },
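The timings printed above translate into rough speedup factors. A quick sketch with the numbers from this run (wall-clock times are hardware-dependent, so treat these as illustrative):

```python
# Wall-clock times (seconds) printed by the cell above
sklearn_time = 1.0483319759368896
pytorch_time = 2.0001697540283203
transformers_time = 13.71584439277649

# Speedup of each simpler model relative to the distilled transformer
print(f"scikit-learn speedup: {transformers_time / sklearn_time:.1f}x")  # → 13.1x
print(f"PyTorch speedup: {transformers_time / pytorch_time:.1f}x")       # → 6.9x
```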
650      {
651        "cell_type": "markdown",
652        "metadata": {
653          "id": "1YMTyqIWDiOB"
654        },
655        "source": [
656          "# Wrapping up\n",
657          "\n",
658          "This notebook showed how frameworks outside of Transformers and ONNX can be used as models in txtai.\n",
659          "\n",
660          "As seen in the timings above, TF-IDF + Logistic Regression is about 13 times faster than a distilled transformer model, and the simple PyTorch network is about 7 times faster. Depending on your accuracy requirements, it may make sense to use a simpler model to get better runtime performance."
661        ]
662      }
663    ]
664  }