{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Pjmz-RORV8E"
      },
      "source": [
        "# Apply labels with zero-shot classification\n",
        "\n",
        "This notebook shows how zero-shot classification can be used to perform text classification, labeling and topic modeling. txtai provides a lightweight wrapper around the zero-shot-classification pipeline in Hugging Face Transformers. This method works impressively well out of the box. Kudos to the Hugging Face team for the phenomenal work on zero-shot classification!\n",
        "\n",
        "The examples in this notebook pick the best matching label from a list of labels for a snippet of text.\n",
        "\n",
        "[tldrstory](https://github.com/neuml/tldrstory) has a full-stack implementation of a zero-shot classification system using Streamlit, FastAPI and Hugging Face Transformers. There is also a [Medium article describing tldrstory](https://towardsdatascience.com/tldrstory-ai-powered-understanding-of-headlines-and-story-text-fc86abd702fc) and zero-shot classification.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dk31rbYjSTYm"
      },
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install `txtai` and all dependencies."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XMQuuun2R06J"
      },
      "source": [
        "%%capture\n",
        "!pip install git+https://github.com/neuml/txtai"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PNPJ95cdTKSS"
      },
      "source": [
        "# Create a Labels instance\n",
        "\n",
        "The Labels instance is the main entry point for zero-shot classification. It is a lightweight wrapper around the zero-shot-classification pipeline in Hugging Face Transformers.\n",
        "\n",
        "In addition to the default model, alternate models can be found on the [Hugging Face model hub](https://huggingface.co/models?search=mnli).\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nTDwXOUeTH2-"
      },
      "source": [
        "%%capture\n",
        "\n",
        "from txtai.pipeline import Labels\n",
        "\n",
        "# Create labels model\n",
        "labels = Labels()\n",
        "\n",
        "# Alternate models can be used by passing the model path as shown below\n",
        "# labels = Labels(\"roberta-large-mnli\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-vGR_piwZZO6"
      },
      "source": [
        "# Applying labels to text\n",
        "\n",
        "The example below shows how a zero-shot classifier can be applied to arbitrary text. The default model for the zero-shot classification pipeline is *bart-large-mnli*.\n",
        "\n",
        "Look at the results below. It's nothing short of amazing ✨ how well it performs. Not all of these are simple, even for a human. For example, \"Intercepted!\" was purposely picked since interceptions are more common in football than basketball, yet the model still infers basketball from the surrounding context. The amount of knowledge stored in larger Transformer models continues to impress me."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-K2YJJzsVtfq",
        "outputId": "7a1edf58-15e0-46c8-958e-3a8e6045f802"
      },
      "source": [
        "data = [\"Dodgers lose again, give up 3 HRs in a loss to the Giants\",\n",
        "        \"Giants 5 Cardinals 4 final in extra innings\",\n",
        "        \"Dodgers drop Game 2 against the Giants, 5-4\",\n",
        "        \"Flyers 4 Lightning 1 final. 45 saves for the Lightning.\",\n",
        "        \"Slashing, penalty, 2 minute power play coming up\",\n",
        "        \"What a stick save!\",\n",
        "        \"Leads the NFL in sacks with 9.5\",\n",
        "        \"UCF 38 Temple 13\",\n",
        "        \"With the 30 yard completion, down to the 10 yard line\",\n",
        "        \"Drains the 3pt shot!!, 0:15 remaining in the game\",\n",
        "        \"Intercepted! Drives down the court and shoots for the win\",\n",
        "        \"Massive dunk!!! they are now up by 15 with 2 minutes to go\"]\n",
        "\n",
        "# List of labels\n",
        "tags = [\"Baseball\", \"Football\", \"Hockey\", \"Basketball\"]\n",
        "\n",
        "print(\"%-75s %s\" % (\"Text\", \"Label\"))\n",
        "print(\"-\" * 100)\n",
        "\n",
        "for text in data:\n",
        "    print(\"%-75s %s\" % (text, tags[labels(text, tags)[0][0]]))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Text                                                                        Label\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dodgers lose again, give up 3 HRs in a loss to the Giants                   Baseball\n",
            "Giants 5 Cardinals 4 final in extra innings                                 Baseball\n",
            "Dodgers drop Game 2 against the Giants, 5-4                                 Baseball\n",
            "Flyers 4 Lightning 1 final. 45 saves for the Lightning.                     Hockey\n",
            "Slashing, penalty, 2 minute power play coming up                            Hockey\n",
            "What a stick save!                                                          Hockey\n",
            "Leads the NFL in sacks with 9.5                                             Football\n",
            "UCF 38 Temple 13                                                            Football\n",
            "With the 30 yard completion, down to the 10 yard line                       Football\n",
            "Drains the 3pt shot!!, 0:15 remaining in the game                           Basketball\n",
            "Intercepted! Drives down the court and shoots for the win                   Basketball\n",
            "Massive dunk!!! they are now up by 15 with 2 minutes to go                  Basketball\n"
          ]
        }
      ]
    },
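    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop above only uses the id of the top label. As a minimal sketch (assuming the pipeline returns a list of (id, score) tuples sorted by score, which is what the `[0][0]` indexing above relies on), the cell below prints the full score distribution for a single snippet.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Print each label with its score for one input\n",
        "# Assumes labels(text, tags) returns [(id, score), ...] sorted by score\n",
        "for uid, score in labels(\"What a stick save!\", tags):\n",
        "    print(\"%-12s %.4f\" % (tags[uid], score))"
      ],
      "execution_count": null,
      "outputs": []
    },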
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t-tGAzCxsHLy"
      },
      "source": [
        "# Let's try emoji 😀\n",
        "\n",
        "Does the model have knowledge of emoji? Check out the run below; it sure looks like it does! Notice the labels are applied based on the perspective from which the information is presented."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uIf064M9pbjn",
        "outputId": "1d104014-e9ca-4c89-d259-2b5b231840ad"
      },
      "source": [
        "tags = [\"😀\", \"😡\"]\n",
        "\n",
        "print(\"%-75s %s\" % (\"Text\", \"Label\"))\n",
        "print(\"-\" * 100)\n",
        "\n",
        "for text in data:\n",
        "    print(\"%-75s %s\" % (text, tags[labels(text, tags)[0][0]]))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Text                                                                        Label\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Dodgers lose again, give up 3 HRs in a loss to the Giants                   😡\n",
            "Giants 5 Cardinals 4 final in extra innings                                 😀\n",
            "Dodgers drop Game 2 against the Giants, 5-4                                 😡\n",
            "Flyers 4 Lightning 1 final. 45 saves for the Lightning.                     😀\n",
            "Slashing, penalty, 2 minute power play coming up                            😡\n",
            "What a stick save!                                                          😀\n",
            "Leads the NFL in sacks with 9.5                                             😀\n",
            "UCF 38 Temple 13                                                            😀\n",
            "With the 30 yard completion, down to the 10 yard line                       😀\n",
            "Drains the 3pt shot!!, 0:15 remaining in the game                           😀\n",
            "Intercepted! Drives down the court and shoots for the win                   😀\n",
            "Massive dunk!!! they are now up by 15 with 2 minutes to go                  😀\n"
          ]
        }
      ]
    }
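,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a closing sketch (assuming the pipeline also accepts a list of texts and returns one list of (id, score) tuples per input), all snippets can be classified in a single batched call.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# Classify every snippet in one batched call\n",
        "# Assumes labels(list, tags) returns one result list per input text\n",
        "tags = [\"Baseball\", \"Football\", \"Hockey\", \"Basketball\"]\n",
        "for text, result in zip(data, labels(data, tags)):\n",
        "    print(\"%-75s %s\" % (text, tags[result[0][0]]))"
      ],
      "execution_count": null,
      "outputs": []
    }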
  ]
}