outlines_llama-cpp-python_chain_of_thought.ipynb
1 { 2 "cells": [ 3 { 4 "cell_type": "markdown", 5 "id": "10d66530-b4e2-4b47-9413-92240069c5e1", 6 "metadata": {}, 7 "source": [ 8 "# Chain of Thought (CoT)" 9 ] 10 }, 11 { 12 "cell_type": "markdown", 13 "id": "2652e68d-4278-4e8f-9132-6ba90905d073", 14 "metadata": {}, 15 "source": [ 16 "Chain of thought is a prompting technique introduced in the paper [\"Chain-of-Thought Prompting Elicits Reasoning in Large Language Models\"](https://arxiv.org/abs/2201.11903), in which the authors show that prompting a model to generate a series of intermediate reasoning steps improves the ability of LLMs to perform complex reasoning.\n", 17 "\n", 18 "In this guide, we use [outlines](https://outlines-dev.github.io/outlines/) to apply chain of thought through structured output with the quantized `Hermes-2-Pro-Llama-3-8B`." 19 ] 20 }, 21 { 22 "cell_type": "markdown", 23 "id": "14f30dda-18f6-415e-ac23-e87aeb636f83", 24 "metadata": {}, 25 "source": [ 26 "## Requirements\n", 27 "\n", 28 "### Install llama-cpp-python and outlines" 29 ] 30 }, 31 { 32 "cell_type": "code", 33 "execution_count": 1, 34 "id": "467c6655-4dd6-4f4e-ad9b-42fc1de0f52c", 35 "metadata": { 36 "execution": { 37 "iopub.execute_input": "2024-08-08T16:26:05.655426Z", 38 "iopub.status.busy": "2024-08-08T16:26:05.654799Z", 39 "iopub.status.idle": "2024-08-08T16:26:05.661046Z", 40 "shell.execute_reply": "2024-08-08T16:26:05.659644Z", 41 "shell.execute_reply.started": "2024-08-08T16:26:05.655375Z" 42 } 43 }, 44 "outputs": [], 45 "source": [ 46 "# RUN IT ONLY ONCE TO INSTALL THE REQUIREMENTS\n", 47 "# %pip install llama-cpp-python outlines" 48 ] 49 }, 50 { 51 "cell_type": "markdown", 52 "id": "3da62069-96d7-43e5-8968-64b48bc1384b", 53 "metadata": {}, 54 "source": [ 55 "For detailed installation instructions, see [llama-cpp-python installation](https://llama-cpp-python.readthedocs.io/en/stable/) and [outlines installation](https://outlines-dev.github.io/outlines/installation/).\n", 56 "\n", 57 "### Pull the model from HuggingFace\n", 58 
"\n", 59 "Download a GGUF model from HuggingFace [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF/tree/main), for example, the Q4_K_M one (it requires 4.92 GB):" 60 ] 61 }, 62 { 63 "cell_type": "code", 64 "execution_count": 2, 65 "id": "ba6e9b01-0ad9-4f40-ac99-2e9340c1d3b1", 66 "metadata": { 67 "execution": { 68 "iopub.execute_input": "2024-08-08T16:26:05.662881Z", 69 "iopub.status.busy": "2024-08-08T16:26:05.662450Z", 70 "iopub.status.idle": "2024-08-08T16:26:05.686433Z", 71 "shell.execute_reply": "2024-08-08T16:26:05.685275Z", 72 "shell.execute_reply.started": "2024-08-08T16:26:05.662839Z" 73 } 74 }, 75 "outputs": [], 76 "source": [ 77 "# RUN IT ONLY ONCE TO DOWNLOAD THE GGUF MODEL, IN THIS CASE THE Q4_K_M\n", 78 "# !wget https://hf.co/NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF/resolve/main/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf" 79 ] 80 }, 81 { 82 "cell_type": "markdown", 83 "id": "467fa3ae-28bf-4636-9f23-92b3204df17d", 84 "metadata": {}, 85 "source": [ 86 "## Usage\n", 87 "\n", 88 "### Chain of Thought\n", 89 "\n", 90 "### Define Pydantic class\n", 91 "\n", 92 "We first define our Pydantic class for a reasoning step" 93 ] 94 }, 95 { 96 "cell_type": "code", 97 "execution_count": 3, 98 "id": "96cdbccc-c584-4966-a442-02741a171ab2", 99 "metadata": { 100 "execution": { 101 "iopub.execute_input": "2024-08-08T16:26:05.688514Z", 102 "iopub.status.busy": "2024-08-08T16:26:05.688043Z", 103 "iopub.status.idle": "2024-08-08T16:26:05.813020Z", 104 "shell.execute_reply": "2024-08-08T16:26:05.811943Z", 105 "shell.execute_reply.started": "2024-08-08T16:26:05.688469Z" 106 } 107 }, 108 "outputs": [], 109 "source": [ 110 "from pydantic import BaseModel, Field\n", 111 "\n", 112 "class Reasoning_Step(BaseModel):\n", 113 " reasoning_step: str = Field(..., description=\"Reasoning step\")" 114 ] 115 }, 116 { 117 "cell_type": "markdown", 118 "id": "1691514f-094c-4b5f-8065-71caddd23ca7", 119 "metadata": {}, 120 "source": [ 121 "We then define the Pydantic class for 
reasoning which will consist of a list of reasoning steps and a conclusion" 122 ] 123 }, 124 { 125 "cell_type": "code", 126 "execution_count": 4, 127 "id": "532995c5-f076-4bf7-b294-ed70332c1c10", 128 "metadata": { 129 "execution": { 130 "iopub.execute_input": "2024-08-08T16:26:05.814802Z", 131 "iopub.status.busy": "2024-08-08T16:26:05.814393Z", 132 "iopub.status.idle": "2024-08-08T16:26:05.823195Z", 133 "shell.execute_reply": "2024-08-08T16:26:05.822209Z", 134 "shell.execute_reply.started": "2024-08-08T16:26:05.814762Z" 135 } 136 }, 137 "outputs": [], 138 "source": [ 139 "from typing import List\n", 140 "\n", 141 "class Reasoning(BaseModel):\n", 142 " reasoning: List[Reasoning_Step] = Field(..., description=\"List of reasoning steps\")\n", 143 " conclusion: str = Field(..., description=\"Conclusion\")" 144 ] 145 }, 146 { 147 "cell_type": "markdown", 148 "id": "83358efd-2411-4383-962e-109b9d8afcc8", 149 "metadata": {}, 150 "source": [ 151 "### Load the model" 152 ] 153 }, 154 { 155 "cell_type": "code", 156 "execution_count": 5, 157 "id": "64f2fe73-56f7-4533-9357-394d5c6555dd", 158 "metadata": { 159 "execution": { 160 "iopub.execute_input": "2024-08-08T16:26:05.825224Z", 161 "iopub.status.busy": "2024-08-08T16:26:05.824745Z", 162 "iopub.status.idle": "2024-08-08T16:26:14.517480Z", 163 "shell.execute_reply": "2024-08-08T16:26:14.516116Z", 164 "shell.execute_reply.started": "2024-08-08T16:26:05.825179Z" 165 } 166 }, 167 "outputs": [ 168 { 169 "name": "stderr", 170 "output_type": "stream", 171 "text": [ 172 "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" 173 ] 174 } 175 ], 176 "source": [ 177 "import llama_cpp\n", 178 "from llama_cpp import Llama\n", 179 "from outlines import generate, models\n", 180 "\n", 181 "llm = Llama(\n", 182 " \"/big_storage/llms/models/Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf\",\n", 183 " tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(\n", 184 " 
\"NousResearch/Hermes-2-Pro-Llama-3-8B\"\n", 185 " ),\n", 186 " n_gpu_layers=-1,\n", 187 " flash_attn=True,\n", 188 " n_ctx=8192,\n", 189 " verbose=False\n", 190 ")\n", 191 "\n", 192 "model = models.LlamaCpp(llm)" 193 ] 194 }, 195 { 196 "cell_type": "code", 197 "execution_count": 6, 198 "id": "b33f0a08-a699-4682-a50a-e5b21acb7645", 199 "metadata": { 200 "execution": { 201 "iopub.execute_input": "2024-08-08T16:26:14.519387Z", 202 "iopub.status.busy": "2024-08-08T16:26:14.519194Z", 203 "iopub.status.idle": "2024-08-08T16:26:14.522935Z", 204 "shell.execute_reply": "2024-08-08T16:26:14.522297Z", 205 "shell.execute_reply.started": "2024-08-08T16:26:14.519372Z" 206 } 207 }, 208 "outputs": [], 209 "source": [ 210 "import warnings\n", 211 "warnings.filterwarnings(\"ignore\", category=RuntimeWarning) # ignore runtime warnings" 212 ] 213 }, 214 { 215 "cell_type": "markdown", 216 "id": "91149a2f-0d15-4d8f-8827-73a15a38464f", 217 "metadata": {}, 218 "source": [ 219 "We build a regex from the `Reasoning` Pydantic class which the model will be forced to follow" 220 ] 221 }, 222 { 223 "cell_type": "code", 224 "execution_count": 7, 225 "id": "a125eb20-efb2-4274-a42d-a35f58d9db54", 226 "metadata": { 227 "execution": { 228 "iopub.execute_input": "2024-08-08T16:26:14.523680Z", 229 "iopub.status.busy": "2024-08-08T16:26:14.523509Z", 230 "iopub.status.idle": "2024-08-08T16:26:14.563053Z", 231 "shell.execute_reply": "2024-08-08T16:26:14.562077Z", 232 "shell.execute_reply.started": "2024-08-08T16:26:14.523666Z" 233 } 234 }, 235 "outputs": [ 236 { 237 "data": { 238 "text/plain": [ 239 "'\\\\{ \"reasoning\" : \\\\[ ((\\\\{ \"reasoning_step\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\})(, (\\\\{ \"reasoning_step\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\})){0,})? 
\\\\] , \"conclusion\" : \"([^\"\\\\\\\\\\\\x00-\\\\x1F\\\\x7F-\\\\x9F]|\\\\\\\\[\"\\\\\\\\])*\" \\\\}'" 240 ] 241 }, 242 "execution_count": 7, 243 "metadata": {}, 244 "output_type": "execute_result" 245 } 246 ], 247 "source": [ 248 "from outlines.integrations.utils import convert_json_schema_to_str\n", 249 "from outlines.fsm.json_schema import build_regex_from_schema\n", 250 "\n", 251 "json_schema = Reasoning.model_json_schema()\n", 252 "schema_str = convert_json_schema_to_str(json_schema=json_schema)\n", 253 "regex_str = build_regex_from_schema(schema_str, whitespace_pattern=r\" \")\n", 254 "regex_str" 255 ] 256 }, 257 { 258 "cell_type": "markdown", 259 "id": "a853e5b6-1a40-43aa-a03d-4d5ece02b9c7", 260 "metadata": {}, 261 "source": [ 262 "We then need to adapt our prompt to the [Hermes prompt format for JSON schema](https://github.com/NousResearch/Hermes-Function-Calling?tab=readme-ov-file#prompt-format-for-json-mode--structured-outputs)" 263 ] 264 }, 265 { 266 "cell_type": "code", 267 "execution_count": 8, 268 "id": "3fddb1a4-fc3d-4a81-bfb6-ab07c94c16e2", 269 "metadata": { 270 "execution": { 271 "iopub.execute_input": "2024-08-08T16:26:14.564898Z", 272 "iopub.status.busy": "2024-08-08T16:26:14.564446Z", 273 "iopub.status.idle": "2024-08-08T16:26:14.570846Z", 274 "shell.execute_reply": "2024-08-08T16:26:14.569634Z", 275 "shell.execute_reply.started": "2024-08-08T16:26:14.564854Z" 276 } 277 }, 278 "outputs": [], 279 "source": [ 280 "def generate_hermes_prompt(question):\n", 281 " return (\n", 282 " \"<|im_start|>system\\n\"\n", 283 " \"You are a world class AI model who answers questions in JSON with correct Pydantic schema. 
\"\n", 284 "Here's the json schema you must adhere to:\\n<schema>\\n\" + str(json_schema) + \"\\n</schema>\"\n", 285 "\\n<|im_start|>user\\n\" + question + \"<|im_end|>\"\n", 286 "\\n<|im_start|>assistant\\n\"\n", 287 " )" 288 ] 289 }, 290 { 291 "cell_type": "markdown", 292 "id": "133966dc-642a-4c64-983a-66f0ece78b2b", 293 "metadata": {}, 294 "source": [ 295 "For a given `user_prompt` we obtain the Hermes prompt" 296 ] 297 }, 298 { 299 "cell_type": "code", 300 "execution_count": 9, 301 "id": "939e5c40-81ce-4a5a-b253-480e1a0bb1b6", 302 "metadata": { 303 "execution": { 304 "iopub.execute_input": "2024-08-08T16:26:14.572631Z", 305 "iopub.status.busy": "2024-08-08T16:26:14.572211Z", 306 "iopub.status.idle": "2024-08-08T16:26:14.588481Z", 307 "shell.execute_reply": "2024-08-08T16:26:14.587336Z", 308 "shell.execute_reply.started": "2024-08-08T16:26:14.572591Z" 309 } 310 }, 311 "outputs": [ 312 { 313 "name": "stdout", 314 "output_type": "stream", 315 "text": [ 316 "<|im_start|>system\n", 317 "You are a world class AI model who answers questions in JSON with correct Pydantic schema. 
Here's the json schema you must adhere to:\n", 318 "<schema>\n", 319 "{'$defs': {'Reasoning_Step': {'properties': {'reasoning_step': {'description': 'Reasoning step', 'title': 'Reasoning Step', 'type': 'string'}}, 'required': ['reasoning_step'], 'title': 'Reasoning_Step', 'type': 'object'}}, 'properties': {'reasoning': {'description': 'List of reasoning steps', 'items': {'$ref': '#/$defs/Reasoning_Step'}, 'title': 'Reasoning', 'type': 'array'}, 'conclusion': {'description': 'Conclusion', 'title': 'Conclusion', 'type': 'string'}}, 'required': ['reasoning', 'conclusion'], 'title': 'Reasoning', 'type': 'object'}\n", 320 "</schema>\n", 321 "<|im_start|>user\n", 322 "9.11 and 9.9 -- which is bigger?<|im_end|>\n", 323 "<|im_start|>assistant\n", 324 "\n" 325 ] 326 } 327 ], 328 "source": [ 329 "user_prompt = \"9.11 and 9.9 -- which is bigger?\"\n", 330 "prompt = generate_hermes_prompt(user_prompt)\n", 331 "print(prompt)" 332 ] 333 }, 334 { 335 "cell_type": "markdown", 336 "id": "a0608a24-7f39-4ed8-9b6b-d551d9d556a6", 337 "metadata": {}, 338 "source": [ 339 "We use `generate.regex`, passing it the regex string we built from the `Reasoning` Pydantic class, and call the resulting generator with the Hermes prompt:" 340 ] 341 }, 342 { 343 "cell_type": "code", 344 "execution_count": 10, 345 "id": "97e39197-cf88-4008-8e6d-814dec90a9a8", 346 "metadata": { 347 "execution": { 348 "iopub.execute_input": "2024-08-08T16:26:14.590357Z", 349 "iopub.status.busy": "2024-08-08T16:26:14.589904Z", 350 "iopub.status.idle": "2024-08-08T16:26:18.016969Z", 351 "shell.execute_reply": "2024-08-08T16:26:18.015891Z", 352 "shell.execute_reply.started": "2024-08-08T16:26:14.590313Z" 353 } 354 }, 355 "outputs": [ 356 { 357 "data": { 358 "text/plain": [ 359 "'{ \"reasoning\" : [ { \"reasoning_step\" : \"Both 9.11 and 9.9 are decimal numbers.\" }, { \"reasoning_step\" : \"When comparing decimal numbers, we look at the numbers after the decimal point.\" }, { \"reasoning_step\" : \"In this case, 9.11 has the number 1 after the decimal 
point, while 9.9 has the number 9.\" }, { \"reasoning_step\" : \"Since 1 is greater than 9, 9.11 is greater than 9.9.\" } ], \"conclusion\" : \"9.11 is bigger.\" }'" 360 ] 361 }, 362 "execution_count": 10, 363 "metadata": {}, 364 "output_type": "execute_result" 365 } 366 ], 367 "source": [ 368 "generator = generate.regex(model, regex_str)\n", 369 "response = generator(prompt, max_tokens=1024, temperature=0, seed=42)\n", 370 "response" 371 ] 372 }, 373 { 374 "cell_type": "markdown", 375 "id": "7d1f9afb-83bf-4b48-be71-23e80b83f1e8", 376 "metadata": {}, 377 "source": [ 378 "We obtain a series of intermediate reasoning steps as well as the conclusion" 379 ] 380 }, 381 { 382 "cell_type": "code", 383 "execution_count": 11, 384 "id": "8d0062e0-e49f-4034-8606-a6491f8fd154", 385 "metadata": { 386 "execution": { 387 "iopub.execute_input": "2024-08-08T16:26:18.018100Z", 388 "iopub.status.busy": "2024-08-08T16:26:18.017899Z", 389 "iopub.status.idle": "2024-08-08T16:26:18.023265Z", 390 "shell.execute_reply": "2024-08-08T16:26:18.022596Z", 391 "shell.execute_reply.started": "2024-08-08T16:26:18.018083Z" 392 } 393 }, 394 "outputs": [ 395 { 396 "data": { 397 "text/plain": [ 398 "[{'reasoning_step': 'Both 9.11 and 9.9 are decimal numbers.'},\n", 399 " {'reasoning_step': 'When comparing decimal numbers, we look at the numbers after the decimal point.'},\n", 400 " {'reasoning_step': 'In this case, 9.11 has the number 1 after the decimal point, while 9.9 has the number 9.'},\n", 401 " {'reasoning_step': 'Since 1 is greater than 9, 9.11 is greater than 9.9.'}]" 402 ] 403 }, 404 "execution_count": 11, 405 "metadata": {}, 406 "output_type": "execute_result" 407 } 408 ], 409 "source": [ 410 "import json\n", 411 "\n", 412 "json_response = json.loads(response)\n", 413 "json_response[\"reasoning\"]" 414 ] 415 }, 416 { 417 "cell_type": "code", 418 "execution_count": 12, 419 "id": "efa12b2b-800d-4bc6-9a07-7847cfd2842f", 420 "metadata": { 421 "execution": { 422 "iopub.execute_input": 
"2024-08-08T16:26:18.024347Z", 423 "iopub.status.busy": "2024-08-08T16:26:18.024047Z", 424 "iopub.status.idle": "2024-08-08T16:26:18.049441Z", 425 "shell.execute_reply": "2024-08-08T16:26:18.048426Z", 426 "shell.execute_reply.started": "2024-08-08T16:26:18.024323Z" 427 } 428 }, 429 "outputs": [ 430 { 431 "data": { 432 "text/plain": [ 433 "'9.11 is bigger.'" 434 ] 435 }, 436 "execution_count": 12, 437 "metadata": {}, 438 "output_type": "execute_result" 439 } 440 ], 441 "source": [ 442 "json_response[\"conclusion\"]" 443 ] 444 }, 445 { 446 "cell_type": "markdown", 447 "id": "90363279-bce2-4ffd-aa38-e582e8523014", 448 "metadata": {}, 449 "source": [ 450 "We notice that the fourth reasoning step is wrong (`Since 1 is greater than 9, 9.11 is greater than 9.9.`), so we should probably give the model a few examples for this particular task." 451 ] 452 } 453 ], 454 "metadata": { 455 "kernelspec": { 456 "display_name": "Python 3 (ipykernel)", 457 "language": "python", 458 "name": "python3" 459 }, 460 "language_info": { 461 "codemirror_mode": { 462 "name": "ipython", 463 "version": 3 464 }, 465 "file_extension": ".py", 466 "mimetype": "text/x-python", 467 "name": "python", 468 "nbconvert_exporter": "python", 469 "pygments_lexer": "ipython3", 470 "version": "3.10.12" 471 } 472 }, 473 "nbformat": 4, 474 "nbformat_minor": 5 475 }