03_Build_an_Embeddings_index_from_a_data_source.ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WDbhGHtG8jFE"
   },
   "source": [
    "# Build an Embeddings index from a data source\n",
    "\n",
    "In Part 1, we gave a general overview of txtai, the backing technology and examples of how to use it for similarity searches. Part 2 covered building an embeddings index over a larger dataset.\n",
    "\n",
    "For real-world, large-scale use cases, data is often stored in a database or other data store (Elasticsearch, SQL, MongoDB, files, etc.). Here we'll show how to read from SQLite, build an embeddings index and run queries against the generated index.\n",
    "\n",
    "This example covers functionality found in the [paperai](https://github.com/neuml/paperai) library. See that library for a full solution that can be used with the dataset discussed below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UQ0fCwXn9bcH"
   },
   "source": [
    "# Install dependencies\n",
    "\n",
    "Install `txtai` and all dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "czPYSA2Q9ZHO"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install git+https://github.com/neuml/txtai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SN9SCZKQ9fJF"
   },
   "source": [
    "# Download data\n",
    "\n",
    "This example works with a subset of the [CORD-19](https://www.semanticscholar.org/cord19) dataset. The COVID-19 Open Research Dataset (CORD-19) is a free resource of scholarly articles, aggregated by a coalition of leading research groups, covering COVID-19 and the coronavirus family of viruses.\n",
    "\n",
    "The following download is a SQLite database generated from a [Kaggle notebook](https://www.kaggle.com/davidmezzetti/cord-19-slim/output). More information on this data format can be found in the [CORD-19 Analysis](https://www.kaggle.com/davidmezzetti/cord-19-analysis-with-sentence-embeddings) notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TONQ4_Kv9dtd"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!wget https://github.com/neuml/txtai/releases/download/v1.1.0/tests.gz\n",
    "!gunzip tests.gz\n",
    "!mv tests articles.sqlite"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_UxcC1-JGH-d"
   },
   "source": [
    "# Build an embeddings index\n",
    "\n",
    "The following steps build an embeddings index using a vector model designed for medical papers, [PubMedBERT Embeddings](https://huggingface.co/NeuML/pubmedbert-base-embeddings)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5PrrxGRPGHqX",
    "outputId": "61bf7211-6757-4147-8f2f-e4d1ebe58e11"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iterated over 21499 total rows\n"
     ]
    }
   ],
   "source": [
    "import sqlite3\n",
    "\n",
    "# Third-party regex module: supports the variable-length lookbehind used below,\n",
    "# which the standard library re module does not\n",
    "import regex as re\n",
    "\n",
    "from txtai import Embeddings\n",
    "\n",
    "def stream():\n",
    "    # Connection to database file\n",
    "    db = sqlite3.connect(\"articles.sqlite\")\n",
    "    cur = db.cursor()\n",
    "\n",
    "    # Select tagged sentences that have no NLP label, or a label other than FRAGMENT/QUESTION.\n",
    "    # These NLP labels mark non-informative sentences.\n",
    "    cur.execute(\"SELECT Id, Name, Text FROM sections WHERE (labels is null or labels NOT IN ('FRAGMENT', 'QUESTION')) AND tags is not null\")\n",
    "\n",
    "    count = 0\n",
    "    for row in cur:\n",
    "        # Unpack row\n",
    "        uid, name, text = row\n",
    "\n",
    "        # Only process certain document sections: skip background, introduction and reference\n",
    "        # sections, plus discussion sections not preceded by \"results\" (\"Results and Discussion\" is kept)\n",
    "        if not name or not re.search(r\"background|(?<!.*?results.*?)discussion|introduction|reference\", name.lower()):\n",
    "            # (id, text, tags) tuple passed to embeddings.index\n",
    "            document = (uid, text, None)\n",
    "\n",
    "            count += 1\n",
    "            if count % 1000 == 0:\n",
    "                print(\"Streamed %d documents\" % (count), end=\"\\r\")\n",
    "\n",
    "            yield document\n",
    "\n",
    "    print(\"Iterated over %d total rows\" % (count))\n",
    "\n",
    "    # Free database resources\n",
    "    db.close()\n",
    "\n",
    "# Create embeddings index\n",
    "embeddings = Embeddings(path=\"neuml/pubmedbert-base-embeddings\")\n",
    "\n",
    "# Build embeddings index\n",
    "embeddings.index(stream())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zHk24su3e_gb"
   },
   "source": [
    "# Query data\n",
    "\n",
    "The following runs a query for \"risk factors\" against the embeddings index and finds the top 5 matches. Since this index doesn't store content, `search` returns `(id, score)` tuples. Each id is used to look up the matching section text and its parent article's metadata in SQLite."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 293
    },
    "id": "CRbDhvvDKEl-",
    "outputId": "774f8085-01db-49e2-f025-4fe68afca594"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>Title</th>\n",
       "      <th>Published</th>\n",
       "      <th>Reference</th>\n",
       "      <th>Match</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>Management of osteoarthritis during COVID‐19 pandemic</td>\n",
       "      <td>2020-05-21 00:00:00</td>\n",
       "      <td>https://doi.org/10.1002/cpt.1910</td>\n",
       "      <td>Indeed, risk factors are sex, obesity, genetic factors and mechanical factors (3) .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>Does apolipoprotein E genotype predict COVID-19 severity?</td>\n",
       "      <td>2020-04-27 00:00:00</td>\n",
       "      <td>https://doi.org/10.1093/qjmed/hcaa142</td>\n",
       "      <td>Risk factors associated with subsequent death include older age, hypertension, diabetes, ischemic heart disease, obesity and chronic lung disease; however, sometimes there are no obvious risk factors .</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>Prevalence and Impact of Myocardial Injury in Patients Hospitalized with COVID-19 Infection</td>\n",
       "      <td>2020-04-24 00:00:00</td>\n",
       "      <td>http://medrxiv.org/cgi/content/short/2020.04.20.20072702v1?rss=1</td>\n",
       "      <td>This risk was consistent across patients stratified by history of CVD, risk factors but no CVD, and neither CVD nor risk factors.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>COVID-19 and associations with frailty and multimorbidity: a prospective analysis of UK Biobank participants</td>\n",
       "      <td>2020-07-23 00:00:00</td>\n",
       "      <td>https://www.ncbi.nlm.nih.gov/pubmed/32705587/</td>\n",
       "      <td>BACKGROUND: Frailty and multimorbidity have been suggested as risk factors for severe COVID-19 disease.</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>Risk Stratification for Healthcare workers during the CoViD-19 Pandemic; using demographics, co-morbid disease and clinical domain in order to assign clinical duties</td>\n",
       "      <td>2020-05-09 00:00:00</td>\n",
       "      <td>http://medrxiv.org/cgi/content/short/2020.05.05.20091967v1?rss=1</td>\n",
       "      <td>Vascular disease, diabetes and chronic pulmonary disease further increased risk.</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "pd.set_option(\"display.max_colwidth\", None)\n",
    "\n",
    "db = sqlite3.connect(\"articles.sqlite\")\n",
    "cur = db.cursor()\n",
    "\n",
    "results = []\n",
    "for uid, score in embeddings.search(\"risk factors\", 5):\n",
    "    # Look up the matching section text and the id of its parent article\n",
    "    cur.execute(\"SELECT article, text FROM sections WHERE id = ?\", [uid])\n",
    "    article_id, text = cur.fetchone()\n",
    "\n",
    "    # Look up the article metadata and append the matching text\n",
    "    cur.execute(\"SELECT Title, Published, Reference FROM articles WHERE id = ?\", [article_id])\n",
    "    results.append(cur.fetchone() + (text,))\n",
    "\n",
    "# Free database resources\n",
    "db.close()\n",
    "\n",
    "df = pd.DataFrame(results, columns=[\"Title\", \"Published\", \"Reference\", \"Match\"])\n",
    "\n",
    "# Displaying HTML within VS Code has been reported not to work.\n",
    "# In that case, export the data to an external HTML file and view it there.\n",
    "# See example below.\n",
    "\n",
    "# htmlData = df.to_html(index=False)\n",
    "# with open(\"data.html\", \"w\") as file:\n",
    "#     file.write(htmlData)\n",
    "\n",
    "display(HTML(df.to_html(index=False)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XSf68I-ZfXOG"
   },
   "source": [
    "# Extracting additional columns from query results\n",
    "\n",
    "The example above uses the Embeddings index to find the top 5 matches. The following builds on it with an Extractor instance (explained further in Part 5), which asks additional questions over the search results to create a richer query response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TLVOTQJchvTi"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "from txtai.pipeline import Extractor\n",
    "\n",
    "# Create extractor instance using a QA model designed for the CORD-19 dataset\n",
    "# Note: Extractive QA was a predecessor to Large Language Models (LLMs). LLMs will likely get better results.\n",
    "extractor = Extractor(embeddings, \"NeuML/bert-small-cord19qa\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 293
    },
    "id": "19fmKawThs6d",
    "outputId": "b7cd40e3-a87c-419d-f520-b7795607cebc"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>Title</th>\n",
       "      <th>Published</th>\n",
       "      <th>Reference</th>\n",
       "      <th>Match</th>\n",
       "      <th>Risk Factors</th>\n",
       "      <th>Locations</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>Management of osteoarthritis during COVID‐19 pandemic</td>\n",
       "      <td>2020-05-21 00:00:00</td>\n",
       "      <td>https://doi.org/10.1002/cpt.1910</td>\n",
       "      <td>Indeed, risk factors are sex, obesity, genetic factors and mechanical factors (3) .</td>\n",
       "      <td>sex, obesity, genetic factors and mechanical factors</td>\n",
       "      <td>hospitals and clinics</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>Does apolipoprotein E genotype predict COVID-19 severity?</td>\n",
       "      <td>2020-04-27 00:00:00</td>\n",
       "      <td>https://doi.org/10.1093/qjmed/hcaa142</td>\n",
       "      <td>Risk factors associated with subsequent death include older age, hypertension, diabetes, ischemic heart disease, obesity and chronic lung disease; however, sometimes there are no obvious risk factors .</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>Prevalence and Impact of Myocardial Injury in Patients Hospitalized with COVID-19 Infection</td>\n",
       "      <td>2020-04-24 00:00:00</td>\n",
       "      <td>http://medrxiv.org/cgi/content/short/2020.04.20.20072702v1?rss=1</td>\n",
       "      <td>This risk was consistent across patients stratified by history of CVD, risk factors but no CVD, and neither CVD nor risk factors.</td>\n",
       "      <td>neither CVD nor risk factors</td>\n",
       "      <td>Mount Sinai Health System</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>COVID-19 and associations with frailty and multimorbidity: a prospective analysis of UK Biobank participants</td>\n",
       "      <td>2020-07-23 00:00:00</td>\n",
       "      <td>https://www.ncbi.nlm.nih.gov/pubmed/32705587/</td>\n",
       "      <td>BACKGROUND: Frailty and multimorbidity have been suggested as risk factors for severe COVID-19 disease.</td>\n",
       "      <td>Frailty and multimorbidity</td>\n",
       "      <td>213 countries and territories</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>Risk Stratification for Healthcare workers during the CoViD-19 Pandemic; using demographics, co-morbid disease and clinical domain in order to assign clinical duties</td>\n",
       "      <td>2020-05-09 00:00:00</td>\n",
       "      <td>http://medrxiv.org/cgi/content/short/2020.05.05.20091967v1?rss=1</td>\n",
       "      <td>Vascular disease, diabetes and chronic pulmonary disease further increased risk.</td>\n",
       "      <td>Vascular disease, diabetes and chronic pulmonary disease</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "db = sqlite3.connect(\"articles.sqlite\")\n",
    "cur = db.cursor()\n",
    "\n",
    "results = []\n",
    "for uid, score in embeddings.search(\"risk factors\", 5):\n",
    "    # Look up the matching section text and the id of its parent article\n",
    "    cur.execute(\"SELECT article, text FROM sections WHERE id = ?\", [uid])\n",
    "    article_id, text = cur.fetchone()\n",
    "\n",
    "    # Collect the article's text sections to use as the extractor context,\n",
    "    # applying the same label and section name filters used when indexing\n",
    "    cur.execute(\"SELECT Name, Text FROM sections WHERE (labels is null or labels NOT IN ('FRAGMENT', 'QUESTION')) AND article = ? ORDER BY Id\", [article_id])\n",
    "    texts = []\n",
    "    for name, txt in cur.fetchall():\n",
    "        if not name or not re.search(r\"background|(?<!.*?results.*?)discussion|introduction|reference\", name.lower()):\n",
    "            texts.append(txt)\n",
    "\n",
    "    cur.execute(\"SELECT Title, Published, Reference FROM articles WHERE id = ?\", [article_id])\n",
    "    article = cur.fetchone()\n",
    "\n",
    "    # Use QA extractor to derive additional columns. Each question is a\n",
    "    # (name, query, question, snippet) tuple and each answer is a (name, answer) pair.\n",
    "    answers = extractor([(\"Risk Factors\", \"risk factors\", \"What risk factors?\", False),\n",
    "                         (\"Locations\", \"hospital country\", \"What locations?\", False)], texts)\n",
    "\n",
    "    results.append(article + (text,) + tuple([answer[1] for answer in answers]))\n",
    "\n",
    "# Free database resources\n",
    "db.close()\n",
    "\n",
    "df = pd.DataFrame(results, columns=[\"Title\", \"Published\", \"Reference\", \"Match\", \"Risk Factors\", \"Locations\"])\n",
    "display(HTML(df.to_html(index=False)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ColTLy--rWfR"
   },
   "source": [
    "In the example above, the Embeddings index is used to find the top N results for a given query. On top of that, a question-answering extractor derives additional columns from a list of questions. In this case, the \"Risk Factors\" and \"Locations\" columns were pulled from the document text."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "local",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}