/ examples / outlines_llama-cpp-python_chain_of_thought.ipynb
  1  {
  2   "cells": [
  3    {
  4     "cell_type": "markdown",
  5     "id": "10d66530-b4e2-4b47-9413-92240069c5e1",
  6     "metadata": {},
  7     "source": [
  8      "# Chain of Thought (CoT)"
  9     ]
 10    },
 11    {
 12     "cell_type": "markdown",
 13     "id": "2652e68d-4278-4e8f-9132-6ba90905d073",
 14     "metadata": {},
 15     "source": [
 16      "Chain of thought is a prompting technique introduced in the paper [\"Chain-of-Thought Prompting Elicits Reasoning in Large Language Models\"](https://arxiv.org/abs/2201.11903), in which the authors show that prompting a model to generate a series of intermediate reasoning steps improves the ability of LLMs to perform complex reasoning.\n",
 17      "\n",
 18      "In this guide, we use [outlines](https://outlines-dev.github.io/outlines/) to apply chain of thought through structured generation with the quantized `Hermes-2-Pro-Llama-3-8B` model."
 19     ]
 20    },
 21    {
 22     "cell_type": "markdown",
 23     "id": "14f30dda-18f6-415e-ac23-e87aeb636f83",
 24     "metadata": {},
 25     "source": [
 26      "## Requirements\n",
 27      "\n",
 28      "### Install llama-cpp-python and outlines"
 29     ]
 30    },
 31    {
 32     "cell_type": "code",
 33     "execution_count": 1,
 34     "id": "467c6655-4dd6-4f4e-ad9b-42fc1de0f52c",
 35     "metadata": {
 36      "execution": {
 37       "iopub.execute_input": "2024-08-08T16:26:05.655426Z",
 38       "iopub.status.busy": "2024-08-08T16:26:05.654799Z",
 39       "iopub.status.idle": "2024-08-08T16:26:05.661046Z",
 40       "shell.execute_reply": "2024-08-08T16:26:05.659644Z",
 41       "shell.execute_reply.started": "2024-08-08T16:26:05.655375Z"
 42      }
 43     },
 44     "outputs": [],
 45     "source": [
 46      "# RUN IT ONLY ONCE TO INSTALL THE REQUIREMENTS\n",
 47      "# %pip install llama-cpp-python outlines"
 48     ]
 49    },
 50    {
 51     "cell_type": "markdown",
 52     "id": "3da62069-96d7-43e5-8968-64b48bc1384b",
 53     "metadata": {},
 54     "source": [
 55      "For detailed installation instructions, see the [llama-cpp-python installation](https://llama-cpp-python.readthedocs.io/en/stable/) and [outlines installation](https://outlines-dev.github.io/outlines/installation/) documentation.\n",
 56      "\n",
 57      "### Pull the model from HuggingFace\n",
 58      "\n",
 59      "Download a GGUF model from [HuggingFace](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF/tree/main), for example the Q4_K_M quantization (it requires 4.92 GB):"
 60     ]
 61    },
 62    {
 63     "cell_type": "code",
 64     "execution_count": 2,
 65     "id": "ba6e9b01-0ad9-4f40-ac99-2e9340c1d3b1",
 66     "metadata": {
 67      "execution": {
 68       "iopub.execute_input": "2024-08-08T16:26:05.662881Z",
 69       "iopub.status.busy": "2024-08-08T16:26:05.662450Z",
 70       "iopub.status.idle": "2024-08-08T16:26:05.686433Z",
 71       "shell.execute_reply": "2024-08-08T16:26:05.685275Z",
 72       "shell.execute_reply.started": "2024-08-08T16:26:05.662839Z"
 73      }
 74     },
 75     "outputs": [],
 76     "source": [
 77      "# RUN IT ONLY ONCE TO DOWNLOAD THE GGUF MODEL, IN THIS CASE THE Q4_K_M\n",
 78      "# !wget https://hf.co/NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF/resolve/main/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf"
 79     ]
 80    },
 81    {
 82     "cell_type": "markdown",
 83     "id": "467fa3ae-28bf-4636-9f23-92b3204df17d",
 84     "metadata": {},
 85     "source": [
 86      "## Usage\n",
 87      "\n",
 88      "### Chain of Thought\n",
 89      "\n",
 90      "### Define Pydantic class\n",
 91      "\n",
 92      "We first define a Pydantic class for a single reasoning step:"
 93     ]
 94    },
 95    {
 96     "cell_type": "code",
 97     "execution_count": 3,
 98     "id": "96cdbccc-c584-4966-a442-02741a171ab2",
 99     "metadata": {
100      "execution": {
101       "iopub.execute_input": "2024-08-08T16:26:05.688514Z",
102       "iopub.status.busy": "2024-08-08T16:26:05.688043Z",
103       "iopub.status.idle": "2024-08-08T16:26:05.813020Z",
104       "shell.execute_reply": "2024-08-08T16:26:05.811943Z",
105       "shell.execute_reply.started": "2024-08-08T16:26:05.688469Z"
106      }
107     },
108     "outputs": [],
109     "source": [
110      "from pydantic import BaseModel, Field\n",
111      "\n",
112      "class Reasoning_Step(BaseModel):\n",
113      "    reasoning_step: str = Field(..., description=\"Reasoning step\")"
114     ]
115    },
116    {
117     "cell_type": "markdown",
118     "id": "1691514f-094c-4b5f-8065-71caddd23ca7",
119     "metadata": {},
120     "source": [
121      "We then define the `Reasoning` Pydantic class, which consists of a list of reasoning steps and a conclusion:"
122     ]
123    },
124    {
125     "cell_type": "code",
126     "execution_count": 4,
127     "id": "532995c5-f076-4bf7-b294-ed70332c1c10",
128     "metadata": {
129      "execution": {
130       "iopub.execute_input": "2024-08-08T16:26:05.814802Z",
131       "iopub.status.busy": "2024-08-08T16:26:05.814393Z",
132       "iopub.status.idle": "2024-08-08T16:26:05.823195Z",
133       "shell.execute_reply": "2024-08-08T16:26:05.822209Z",
134       "shell.execute_reply.started": "2024-08-08T16:26:05.814762Z"
135      }
136     },
137     "outputs": [],
138     "source": [
139      "from typing import List\n",
140      "\n",
141      "class Reasoning(BaseModel):\n",
142      "    reasoning: List[Reasoning_Step] = Field(..., description=\"List of reasoning steps\")\n",
143      "    conclusion: str = Field(..., description=\"Conclusion\")"
144     ]
145    },
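As a quick sanity check (our own addition, not part of the original notebook run), we can validate a hand-written JSON payload against these classes with Pydantic v2's `model_validate_json`; the sample values below are purely illustrative:

```python
from typing import List
from pydantic import BaseModel, Field

class Reasoning_Step(BaseModel):
    reasoning_step: str = Field(..., description="Reasoning step")

class Reasoning(BaseModel):
    reasoning: List[Reasoning_Step] = Field(..., description="List of reasoning steps")
    conclusion: str = Field(..., description="Conclusion")

# A hand-written payload with the expected shape (illustrative values only)
sample = '{"reasoning": [{"reasoning_step": "Compare the tenths digits."}], "conclusion": "9.9 is bigger."}'
parsed = Reasoning.model_validate_json(sample)
print(parsed.conclusion)  # 9.9 is bigger.
```

If the payload were missing `conclusion` or had a non-string step, validation would raise a `ValidationError` instead.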
146    {
147     "cell_type": "markdown",
148     "id": "83358efd-2411-4383-962e-109b9d8afcc8",
149     "metadata": {},
150     "source": [
151      "### Load the model"
152     ]
153    },
154    {
155     "cell_type": "code",
156     "execution_count": 5,
157     "id": "64f2fe73-56f7-4533-9357-394d5c6555dd",
158     "metadata": {
159      "execution": {
160       "iopub.execute_input": "2024-08-08T16:26:05.825224Z",
161       "iopub.status.busy": "2024-08-08T16:26:05.824745Z",
162       "iopub.status.idle": "2024-08-08T16:26:14.517480Z",
163       "shell.execute_reply": "2024-08-08T16:26:14.516116Z",
164       "shell.execute_reply.started": "2024-08-08T16:26:05.825179Z"
165      }
166     },
167     "outputs": [
168      {
169       "name": "stderr",
170       "output_type": "stream",
171       "text": [
172        "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
173       ]
174      }
175     ],
176     "source": [
177      "import llama_cpp\n",
178      "from llama_cpp import Llama\n",
179      "from outlines import generate, models\n",
180      "\n",
181      "llm = Llama(\n",
182      "    \"/big_storage/llms/models/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf\",\n",
183      "    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(\n",
184      "        \"NousResearch/Hermes-2-Pro-Llama-3-8B\"\n",
185      "    ),\n",
186      "    n_gpu_layers=-1,\n",
187      "    flash_attn=True,\n",
188      "    n_ctx=8192,\n",
189      "    verbose=False\n",
190      ")\n",
191      "\n",
192      "model = models.LlamaCpp(llm)"
193     ]
194    },
195    {
196     "cell_type": "code",
197     "execution_count": 6,
198     "id": "b33f0a08-a699-4682-a50a-e5b21acb7645",
199     "metadata": {
200      "execution": {
201       "iopub.execute_input": "2024-08-08T16:26:14.519387Z",
202       "iopub.status.busy": "2024-08-08T16:26:14.519194Z",
203       "iopub.status.idle": "2024-08-08T16:26:14.522935Z",
204       "shell.execute_reply": "2024-08-08T16:26:14.522297Z",
205       "shell.execute_reply.started": "2024-08-08T16:26:14.519372Z"
206      }
207     },
208     "outputs": [],
209     "source": [
210      "import warnings\n",
211      "warnings.filterwarnings(\"ignore\", category=RuntimeWarning) # ignore runtime warnings"
212     ]
213    },
214    {
215     "cell_type": "markdown",
216     "id": "91149a2f-0d15-4d8f-8827-73a15a38464f",
217     "metadata": {},
218     "source": [
219      "We build a regex from the `Reasoning` class's JSON schema; the model's output will be constrained to match it"
220     ]
221    },
222    {
223     "cell_type": "code",
224     "execution_count": 7,
225     "id": "a125eb20-efb2-4274-a42d-a35f58d9db54",
226     "metadata": {
227      "execution": {
228       "iopub.execute_input": "2024-08-08T16:26:14.523680Z",
229       "iopub.status.busy": "2024-08-08T16:26:14.523509Z",
230       "iopub.status.idle": "2024-08-08T16:26:14.563053Z",
231       "shell.execute_reply": "2024-08-08T16:26:14.562077Z",
232       "shell.execute_reply.started": "2024-08-08T16:26:14.523666Z"
233      }
234     },
235     "outputs": [
236      {
237       "data": {
238        "text/plain": [
239         "'\\\\{ \"reasoning\" : \\\\[ ((\\\\{ \"reasoning_step\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\})(, (\\\\{ \"reasoning_step\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\})){0,})? \\\\] , \"conclusion\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\}'"
240        ]
241       },
242       "execution_count": 7,
243       "metadata": {},
244       "output_type": "execute_result"
245      }
246     ],
247     "source": [
248      "from outlines.integrations.utils import convert_json_schema_to_str\n",
249      "from outlines.fsm.json_schema import build_regex_from_schema\n",
250      "\n",
251      "json_schema = Reasoning.model_json_schema()\n",
252      "schema_str = convert_json_schema_to_str(json_schema=json_schema)\n",
253      "regex_str = build_regex_from_schema(schema_str, whitespace_pattern=r\" \")\n",
254      "regex_str"
255     ]
256    },
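The generated regex is hard to read; as an illustration of what it encodes (a hand-written miniature of our own, not the outlines output), here is the same idea for a schema with a single `conclusion` string field:

```python
import re

# JSON string literal: no raw quotes, backslashes, or control characters;
# escaped quote/backslash are allowed (a simplified version of the pattern above)
STRING = r'"([^"\\\x00-\x1f]|\\["\\])*"'
pattern = re.compile(r'\{ "conclusion" : ' + STRING + r' \}')

print(bool(pattern.fullmatch('{ "conclusion" : "9.9 is bigger." }')))  # True
print(bool(pattern.fullmatch('{ "conclusion" : 42 }')))                # False
```

During constrained generation, outlines compiles a regex like this into a finite-state machine and masks, at each step, every token that would leave the accepted language.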
257    {
258     "cell_type": "markdown",
259     "id": "a853e5b6-1a40-43aa-a03d-4d5ece02b9c7",
260     "metadata": {},
261     "source": [
262      "We then need to adapt our prompt to the [Hermes prompt format for JSON schema](https://github.com/NousResearch/Hermes-Function-Calling?tab=readme-ov-file#prompt-format-for-json-mode--structured-outputs)"
263     ]
264    },
265    {
266     "cell_type": "code",
267     "execution_count": 8,
268     "id": "3fddb1a4-fc3d-4a81-bfb6-ab07c94c16e2",
269     "metadata": {
270      "execution": {
271       "iopub.execute_input": "2024-08-08T16:26:14.564898Z",
272       "iopub.status.busy": "2024-08-08T16:26:14.564446Z",
273       "iopub.status.idle": "2024-08-08T16:26:14.570846Z",
274       "shell.execute_reply": "2024-08-08T16:26:14.569634Z",
275       "shell.execute_reply.started": "2024-08-08T16:26:14.564854Z"
276      }
277     },
278     "outputs": [],
279     "source": [
280      "def generate_hermes_prompt(question):\n",
281      "    return (\n",
282      "        \"<|im_start|>system\\n\"\n",
283      "        \"You are a world class AI model who answers questions in JSON with correct Pydantic schema. \"\n",
284      "        \"Here's the json schema you must adhere to:\\n<schema>\\n\" + str(json_schema) + \"\\n</schema><|im_end|>\"\n",
285      "        \"\\n<|im_start|>user\\n\" + question + \"<|im_end|>\"\n",
286      "        \"\\n<|im_start|>assistant\\n\"\n",
287      "    )"
288     ]
289    },
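The same structure generalizes to any conversation. Here is a small helper of our own (not part of outlines or llama-cpp-python) that builds a ChatML-style prompt from (role, content) pairs:

```python
# Generic ChatML-style prompt builder (our own sketch): each message is wrapped
# in <|im_start|>role ... <|im_end|>, and the prompt ends with an open
# assistant turn for the model to complete.
def chatml_prompt(messages):
    parts = [f"<|im_start|>{role}\n{content}<|im_end|>\n" for role, content in messages]
    return "".join(parts) + "<|im_start|>assistant\n"

prompt = chatml_prompt([
    ("system", "You are a world class AI model who answers questions in JSON."),
    ("user", "9.11 and 9.9 -- which is bigger?"),
])
print(prompt)
```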
290    {
291     "cell_type": "markdown",
292     "id": "133966dc-642a-4c64-983a-66f0ece78b2b",
293     "metadata": {},
294     "source": [
295      "For a given `user_prompt`, we obtain the Hermes prompt:"
296     ]
297    },
298    {
299     "cell_type": "code",
300     "execution_count": 9,
301     "id": "939e5c40-81ce-4a5a-b253-480e1a0bb1b6",
302     "metadata": {
303      "execution": {
304       "iopub.execute_input": "2024-08-08T16:26:14.572631Z",
305       "iopub.status.busy": "2024-08-08T16:26:14.572211Z",
306       "iopub.status.idle": "2024-08-08T16:26:14.588481Z",
307       "shell.execute_reply": "2024-08-08T16:26:14.587336Z",
308       "shell.execute_reply.started": "2024-08-08T16:26:14.572591Z"
309      }
310     },
311     "outputs": [
312      {
313       "name": "stdout",
314       "output_type": "stream",
315       "text": [
316        "<|im_start|>system\n",
317        "You are a world class AI model who answers questions in JSON with correct Pydantic schema. Here's the json schema you must adhere to:\n",
318        "<schema>\n",
319        "{'$defs': {'Reasoning_Step': {'properties': {'reasoning_step': {'description': 'Reasoning step', 'title': 'Reasoning Step', 'type': 'string'}}, 'required': ['reasoning_step'], 'title': 'Reasoning_Step', 'type': 'object'}}, 'properties': {'reasoning': {'description': 'List of reasoning steps', 'items': {'$ref': '#/$defs/Reasoning_Step'}, 'title': 'Reasoning', 'type': 'array'}, 'conclusion': {'description': 'Conclusion', 'title': 'Conclusion', 'type': 'string'}}, 'required': ['reasoning', 'conclusion'], 'title': 'Reasoning', 'type': 'object'}\n",
320        "</schema><|im_end|>\n",
321        "<|im_start|>user\n",
322        "9.11 and 9.9 -- which is bigger?<|im_end|>\n",
323        "<|im_start|>assistant\n",
324        "\n"
325       ]
326      }
327     ],
328     "source": [
329      "user_prompt = \"9.11 and 9.9 -- which is bigger?\"\n",
330      "prompt = generate_hermes_prompt(user_prompt)\n",
331      "print(prompt)"
332     ]
333    },
334    {
335     "cell_type": "markdown",
336     "id": "a0608a24-7f39-4ed8-9b6b-d551d9d556a6",
337     "metadata": {},
338     "source": [
339      "We create the generator with `generate.regex`, passing the regex string we built from the Pydantic class's schema, and call it with the Hermes prompt:"
340     ]
341    },
342    {
343     "cell_type": "code",
344     "execution_count": 10,
345     "id": "97e39197-cf88-4008-8e6d-814dec90a9a8",
346     "metadata": {
347      "execution": {
348       "iopub.execute_input": "2024-08-08T16:26:14.590357Z",
349       "iopub.status.busy": "2024-08-08T16:26:14.589904Z",
350       "iopub.status.idle": "2024-08-08T16:26:18.016969Z",
351       "shell.execute_reply": "2024-08-08T16:26:18.015891Z",
352       "shell.execute_reply.started": "2024-08-08T16:26:14.590313Z"
353      }
354     },
355     "outputs": [
356      {
357       "data": {
358        "text/plain": [
359         "'{ \"reasoning\" : [ { \"reasoning_step\" : \"Both 9.11 and 9.9 are decimal numbers.\" }, { \"reasoning_step\" : \"When comparing decimal numbers, we look at the numbers after the decimal point.\" }, { \"reasoning_step\" : \"In this case, 9.11 has the number 1 after the decimal point, while 9.9 has the number 9.\" }, { \"reasoning_step\" : \"Since 1 is greater than 9, 9.11 is greater than 9.9.\" } ], \"conclusion\" : \"9.11 is bigger.\" }'"
360        ]
361       },
362       "execution_count": 10,
363       "metadata": {},
364       "output_type": "execute_result"
365      }
366     ],
367     "source": [
368      "generator = generate.regex(model, regex_str)\n",
369      "response = generator(prompt, max_tokens=1024, temperature=0, seed=42)\n",
370      "response"
371     ]
372    },
373    {
374     "cell_type": "markdown",
375     "id": "7d1f9afb-83bf-4b48-be71-23e80b83f1e8",
376     "metadata": {},
377     "source": [
378      "We obtain a series of intermediate reasoning steps as well as the conclusion:"
379     ]
380    },
381    {
382     "cell_type": "code",
383     "execution_count": 11,
384     "id": "8d0062e0-e49f-4034-8606-a6491f8fd154",
385     "metadata": {
386      "execution": {
387       "iopub.execute_input": "2024-08-08T16:26:18.018100Z",
388       "iopub.status.busy": "2024-08-08T16:26:18.017899Z",
389       "iopub.status.idle": "2024-08-08T16:26:18.023265Z",
390       "shell.execute_reply": "2024-08-08T16:26:18.022596Z",
391       "shell.execute_reply.started": "2024-08-08T16:26:18.018083Z"
392      }
393     },
394     "outputs": [
395      {
396       "data": {
397        "text/plain": [
398         "[{'reasoning_step': 'Both 9.11 and 9.9 are decimal numbers.'},\n",
399         " {'reasoning_step': 'When comparing decimal numbers, we look at the numbers after the decimal point.'},\n",
400         " {'reasoning_step': 'In this case, 9.11 has the number 1 after the decimal point, while 9.9 has the number 9.'},\n",
401         " {'reasoning_step': 'Since 1 is greater than 9, 9.11 is greater than 9.9.'}]"
402        ]
403       },
404       "execution_count": 11,
405       "metadata": {},
406       "output_type": "execute_result"
407      }
408     ],
409     "source": [
410      "import json\n",
411      "\n",
412      "json_response = json.loads(response)\n",
413      "json_response[\"reasoning\"]"
414     ]
415    },
416    {
417     "cell_type": "code",
418     "execution_count": 12,
419     "id": "efa12b2b-800d-4bc6-9a07-7847cfd2842f",
420     "metadata": {
421      "execution": {
422       "iopub.execute_input": "2024-08-08T16:26:18.024347Z",
423       "iopub.status.busy": "2024-08-08T16:26:18.024047Z",
424       "iopub.status.idle": "2024-08-08T16:26:18.049441Z",
425       "shell.execute_reply": "2024-08-08T16:26:18.048426Z",
426       "shell.execute_reply.started": "2024-08-08T16:26:18.024323Z"
427      }
428     },
429     "outputs": [
430      {
431       "data": {
432        "text/plain": [
433         "'9.11 is bigger.'"
434        ]
435       },
436       "execution_count": 12,
437       "metadata": {},
438       "output_type": "execute_result"
439      }
440     ],
441     "source": [
442      "json_response[\"conclusion\"]"
443     ]
444    },
445    {
446     "cell_type": "markdown",
447     "id": "90363279-bce2-4ffd-aa38-e582e8523014",
448     "metadata": {},
449     "source": [
450      "We notice that the fourth reasoning step is wrong: `Since 1 is greater than 9, 9.11 is greater than 9.9.`. Comparing the tenths digits, 9 > 1, so 9.9 is actually the bigger number; we should probably give the model a few worked examples for this particular task."
451     ]
452    }
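A minimal sketch of that idea (the helper name and example wording are ours, not from the original notebook): prepend a worked comparison to the question before building the Hermes prompt, so the model sees the correct digit-by-digit procedure:

```python
# One-shot hint (hypothetical helper): show the model a correct decimal
# comparison before asking the real question.
WORKED_EXAMPLE = (
    "Example: to compare 2.5 and 2.37, compare the tenths digits first: "
    "5 > 3, so 2.5 is bigger.\n\n"
)

def with_worked_example(question):
    return WORKED_EXAMPLE + question

user_prompt = with_worked_example("9.11 and 9.9 -- which is bigger?")
print(user_prompt)
```

The resulting `user_prompt` can then be passed to `generate_hermes_prompt` exactly as before.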
453   ],
454   "metadata": {
455    "kernelspec": {
456     "display_name": "Python 3 (ipykernel)",
457     "language": "python",
458     "name": "python3"
459    },
460    "language_info": {
461     "codemirror_mode": {
462      "name": "ipython",
463      "version": 3
464     },
465     "file_extension": ".py",
466     "mimetype": "text/x-python",
467     "name": "python",
468     "nbconvert_exporter": "python",
469     "pygments_lexer": "ipython3",
470     "version": "3.10.12"
471    }
472   },
473   "nbformat": 4,
474   "nbformat_minor": 5
475  }