/ template_tests / hf_chat_template.ipynb
hf_chat_template.ipynb
1 { 2 "cells": [ 3 { 4 "cell_type": "code", 5 "execution_count": 2, 6 "metadata": {}, 7 "outputs": [ 8 { 9 "name": "stdout", 10 "output_type": "stream", 11 "text": [ 12 "\n", 13 "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.2\u001b[0m\n", 14 "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", 15 "Note: you may need to restart the kernel to use updated packages.\n", 16 "Collecting accelerate\n", 17 " Downloading accelerate-0.33.0-py3-none-any.whl.metadata (18 kB)\n", 18 "Requirement already satisfied: numpy<2.0.0,>=1.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (1.26.4)\n", 19 "Requirement already satisfied: packaging>=20.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (24.1)\n", 20 "Requirement already satisfied: psutil in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (6.0.0)\n", 21 "Requirement already satisfied: pyyaml in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (6.0.1)\n", 22 "Requirement already satisfied: torch>=1.10.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (2.2.1+cu121)\n", 23 "Requirement already satisfied: huggingface-hub>=0.21.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (0.24.5)\n", 24 "Requirement already satisfied: safetensors>=0.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (0.4.4)\n", 25 "Requirement already satisfied: filelock in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (3.15.4)\n", 26 
"Requirement already satisfied: fsspec>=2023.5.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2024.6.1)\n", 27 "Requirement already satisfied: requests in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)\n", 28 "Requirement already satisfied: tqdm>=4.42.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.66.4)\n", 29 "Requirement already satisfied: typing-extensions>=3.7.4.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)\n", 30 "Requirement already satisfied: sympy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (1.13.0)\n", 31 "Requirement already satisfied: networkx in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.3)\n", 32 "Requirement already satisfied: jinja2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.4)\n", 33 "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.1.105)\n", 34 "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.1.105)\n", 35 "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.1.105)\n", 36 "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (8.9.2.26)\n", 37 "Requirement already satisfied: 
nvidia-cublas-cu12==12.1.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.1.3.1)\n", 38 "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.0.2.54)\n", 39 "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (10.3.2.106)\n", 40 "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.4.5.107)\n", 41 "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.1.0.106)\n", 42 "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (2.19.3)\n", 43 "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (12.1.105)\n", 44 "Requirement already satisfied: triton==2.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (2.2.0)\n", 45 "Requirement already satisfied: nvidia-nvjitlink-cu12 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.10.0->accelerate) (12.5.82)\n", 46 "Requirement already satisfied: MarkupSafe>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n", 47 "Requirement already satisfied: charset-normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from 
requests->huggingface-hub>=0.21.0->accelerate) (3.3.2)\n", 48 "Requirement already satisfied: idna<4,>=2.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.7)\n", 49 "Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2.2.2)\n", 50 "Requirement already satisfied: certifi>=2017.4.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.7.4)\n", 51 "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n", 52 "Downloading accelerate-0.33.0-py3-none-any.whl (315 kB)\n", 53 "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m315.1/315.1 kB\u001b[0m \u001b[31m25.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 54 "\u001b[?25hInstalling collected packages: accelerate\n", 55 "Successfully installed accelerate-0.33.0\n", 56 "\n", 57 "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.2\u001b[0m\n", 58 "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", 59 "Note: you may need to restart the kernel to use updated packages.\n" 60 ] 61 } 62 ], 63 "source": [ 64 "%pip install --upgrade --quiet transformers\n", 65 "%pip install accelerate" 66 ] 67 }, 68 { 69 "cell_type": "code", 70 "execution_count": 1, 71 "metadata": {}, 72 "outputs": [ 73 { 74 "data": { 75 "application/vnd.jupyter.widget-view+json": { 76 "model_id": "d78346ab141a4861bd3bd8cc48754be5", 77 "version_major": 2, 78 "version_minor": 0 
79 }, 80 "text/plain": [ 81 "Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]" 82 ] 83 }, 84 "metadata": {}, 85 "output_type": "display_data" 86 }, 87 { 88 "name": "stderr", 89 "output_type": "stream", 90 "text": [ 91 "Some parameters are on the meta device device because they were offloaded to the cpu.\n" 92 ] 93 } 94 ], 95 "source": [ 96 "from transformers import AutoTokenizer, AutoModelForCausalLM\n", 97 "\n", 98 "model_path = \"NousResearch/Hermes-2-Pro-Llama-3-8B\"\n", 99 "tokenizer = AutoTokenizer.from_pretrained(model_path)\n", 100 "model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\")" 101 ] 102 }, 103 { 104 "cell_type": "code", 105 "execution_count": 2, 106 "metadata": {}, 107 "outputs": [ 108 { 109 "name": "stdout", 110 "output_type": "stream", 111 "text": [ 112 "{%- macro json_to_python_type(json_spec) %}\n", 113 "{%- set basic_type_map = {\n", 114 " \"string\": \"str\",\n", 115 " \"number\": \"float\",\n", 116 " \"integer\": \"int\",\n", 117 " \"boolean\": \"bool\"\n", 118 "} %}\n", 119 "\n", 120 "{%- if basic_type_map[json_spec.type] is defined %}\n", 121 " {{- basic_type_map[json_spec.type] }}\n", 122 "{%- elif json_spec.type == \"array\" %}\n", 123 " {{- \"list[\" + json_to_python_type(json_spec|items) + \"]\"}}\n", 124 "{%- elif json_spec.type == \"object\" %}\n", 125 " {%- if json_spec.additionalProperties is defined %}\n", 126 " {{- \"dict[str, \" + json_to_python_type(json_spec.additionalProperties) + ']'}}\n", 127 " {%- else %}\n", 128 " {{- \"dict\" }}\n", 129 " {%- endif %}\n", 130 "{%- elif json_spec.type is iterable %}\n", 131 " {{- \"Union[\" }}\n", 132 " {%- for t in json_spec.type %}\n", 133 " {{- json_to_python_type({\"type\": t}) }}\n", 134 " {%- if not loop.last %}\n", 135 " {{- \",\" }} \n", 136 " {%- endif %}\n", 137 " {%- endfor %}\n", 138 " {{- \"]\" }}\n", 139 "{%- else %}\n", 140 " {{- \"Any\" }}\n", 141 "{%- endif %}\n", 142 "{%- endmacro %}\n", 143 "\n", 144 "\n", 145 "{{- bos_token }}\n", 146 "{{- 
\"You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> \" }}\n", 147 "{%- for tool in tools %}\n", 148 " {%- if tool.function is defined %}\n", 149 " {%- set tool = tool.function %}\n", 150 " {%- endif %}\n", 151 " {{- '{\"type\": \"function\", \"function\": ' }}\n", 152 " {{- '{\"name\": \"' + tool.name + '\", ' }}\n", 153 " {{- '\"description\": \"' + tool.name + '(' }}\n", 154 " {%- for param_name, param_fields in tool.parameters.properties|items %}\n", 155 " {{- param_name + \": \" + json_to_python_type(param_fields) }}\n", 156 " {%- if not loop.last %}\n", 157 " {{- \", \" }}\n", 158 " {%- endif %}\n", 159 " {%- endfor %}\n", 160 " {{- \")\" }}\n", 161 " {%- if tool.return is defined %}\n", 162 " {{- \" -> \" + json_to_python_type(tool.return) }}\n", 163 " {%- endif %}\n", 164 " {{- \" - \" + tool.description + \"\\n\\n\" }}\n", 165 " {%- for param_name, param_fields in tool.parameters.properties|items %}\n", 166 " {%- if loop.first %}\n", 167 " {{- \" Args:\\n\" }}\n", 168 " {%- endif %}\n", 169 " {{- \" \" + param_name + \"(\" + json_to_python_type(param_fields) + \"): \" + param_fields.description|trim }}\n", 170 " {%- endfor %}\n", 171 " {%- if tool.return is defined and tool.return.description is defined %}\n", 172 " {{- \"\\n Returns:\\n \" + tool.return.description }}\n", 173 " {%- endif %}\n", 174 " {{- '\"' }}\n", 175 " {{- ', \"parameters\": ' }}\n", 176 " {%- if tool.parameters.properties | length == 0 %}\n", 177 " {{- \"{}\" }}\n", 178 " {%- else %}\n", 179 " {{- tool.parameters|tojson }}\n", 180 " {%- endif %}\n", 181 " {{- \"}\" }}\n", 182 " {%- if not loop.last %}\n", 183 " {{- \"\\n\" }}\n", 184 " {%- endif %}\n", 185 "{%- endfor %}\n", 186 "{{- \" </tools>\" }}\n", 187 "{{- 'Use the following pydantic model json schema 
for each tool call you will make: {\"properties\": {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}}, \"required\": [\"name\", \"arguments\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}\n", 188 "' }}\n", 189 "{{- \"For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n", 190 "\" }}\n", 191 "{{- \"<tool_call>\n", 192 "\" }}\n", 193 "{{- '{\"name\": <function-name>, \"arguments\": <args-dict>}\n", 194 "' }}\n", 195 "{{- '</tool_call><|im_end|>' }}\n", 196 "{%- for message in messages %}\n", 197 " {%- if message.role == \"user\" or message.role == \"system\" or (message.role == \"assistant\" and message.tool_calls is not defined) %}\n", 198 " {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n", 199 " {%- elif message.role == \"assistant\" %}\n", 200 " {{- '<|im_start|>' + message.role }}\n", 201 " {%- for tool_call in message.tool_calls %}\n", 202 " {{- '\n", 203 "<tool_call>\n", 204 "' }} {%- if tool_call.function is defined %}\n", 205 " {%- set tool_call = tool_call.function %}\n", 206 " {%- endif %}\n", 207 " {{- '{' }}\n", 208 " {{- '\"name\": \"' }}\n", 209 " {{- tool_call.name }}\n", 210 " {{- '\"}' }}\n", 211 " {{- ', '}}\n", 212 " {%- if tool_call.arguments is defined %}\n", 213 " {{- '\"arguments\": ' }}\n", 214 " {{- tool_call.arguments|tojson }}\n", 215 " {%- endif %}\n", 216 " {{- '\\n</tool_call>' }}\n", 217 " {%- endfor %}\n", 218 " {{- '<|im_end|>\\n' }}\n", 219 " {%- elif message.role == \"tool\" %}\n", 220 " {%- if not message.name is defined %}\n", 221 " {{- raise_exception(\"Tool response dicts require a 'name' key indicating the name of the called function!\") }}\n", 222 " {%- endif %}\n", 223 " {%- if loop.previtem and loop.previtem.role != \"tool\" %}\n", 224 " {{- '<|im_start|>tool\\n' }}\n", 225 " {%- endif %}\n", 226 " {{- '<tool_response>\\n' }}\n", 227 
" {{- message.content }}\n", 228 " {%- if not loop.last %}\n", 229 " {{- '\\n</tool_response>\\n' }}\n", 230 " {%- else %}\n", 231 " {{- '\\n</tool_response>' }}\n", 232 " {%- endif %}\n", 233 " {%- if not loop.last and loop.nextitem.role != \"tool\" %}\n", 234 " {{- '<|im_end|>' }}\n", 235 " {%- elif loop.last %}\n", 236 " {{- '<|im_end|>' }}\n", 237 " {%- endif %}\n", 238 " {%- endif %}\n", 239 "{%- endfor %}\n", 240 "{%- if add_generation_prompt %}\n", 241 " {{- '<|im_start|>assistant\\n' }}\n", 242 "{%- endif %}\n", 243 "\n" 244 ] 245 } 246 ], 247 "source": [ 248 "print(tokenizer.chat_template[\"tool_use\"])" 249 ] 250 }, 251 { 252 "cell_type": "code", 253 "execution_count": 3, 254 "metadata": {}, 255 "outputs": [], 256 "source": [ 257 "import requests\n", 258 "import random\n", 259 "from datetime import datetime\n", 260 "import pytz\n", 261 "import time\n", 262 "import json\n", 263 "\n", 264 "def get_weather_forecast(location: str) -> dict[str, str]:\n", 265 " \"\"\"\n", 266 " Retrieves a simple weather forecast for a given location.\n", 267 "\n", 268 " Args:\n", 269 " location: The name of the location to get the weather forecast for.\n", 270 "\n", 271 " Returns:\n", 272 " dict[str, str]: A dictionary containing the location, forecast, and temperature.\n", 273 " If successful, the keys are 'location', 'forecast', and 'temperature'.\n", 274 " If unsuccessful, the key is 'error' with a corresponding error message.\n", 275 "\n", 276 " Raises:\n", 277 " requests.RequestException: If there's an error in making the API request.\n", 278 " \"\"\"\n", 279 " url = f\"https://wttr.in/{location}?format=%C,%t\"\n", 280 " response = requests.get(url, timeout=10)\n", 281 " if response.status_code == 200:\n", 282 " condition, temperature = response.text.strip().rsplit(',', 1)\n", 283 " return {\n", 284 " \"location\": location,\n", 285 " \"forecast\": condition,\n", 286 " \"temperature\": temperature\n", 287 " }\n", 288 " else:\n", 289 " return {\"error\": \"Unable to fetch weather 
data\"}\n", 290 "\n", 291 "def get_stock_price(symbol: str) -> str:\n", 292 " \"\"\"\n", 293 " Retrieves the stock price for a given symbol.\n", 294 "\n", 295 " Args:\n", 296 " symbol: The stock symbol to look up.\n", 297 "\n", 298 " Returns:\n", 299 " str: The current stock price or an error message.\n", 300 " \"\"\"\n", 301 " api_key = \"your_stock_api_key\" # Replace with your actual API key\n", 302 " url = f\"https://www.alphavantage.co/query?function=GLOBAL_QUOTE&symbol={symbol}&apikey={api_key}\"\n", 303 " \n", 304 " try:\n", 305 " response = requests.get(url, timeout=10)\n", 306 " response.raise_for_status() # Raises an HTTPError for bad responses\n", 307 " data = response.json()\n", 308 " \n", 309 " if \"Global Quote\" in data and \"05. price\" in data[\"Global Quote\"]:\n", 310 " return f\"${data['Global Quote']['05. price']}\"\n", 311 " elif \"Note\" in data:\n", 312 " return f\"API limit reached: {data['Note']}\"\n", 313 " else:\n", 314 " return f\"Unable to fetch stock price for {symbol}. 
Response: {data}\"\n", 315 " \n", 316 " except requests.RequestException as e:\n", 317 " return f\"Error fetching stock price: {str(e)}\"\n", 318 " except (KeyError, ValueError) as e:\n", 319 " return f\"Error parsing stock price data: {str(e)}\"\n", 320 "\n", 321 "def get_random_number(min_value: int, max_value: int) -> int:\n", 322 " \"\"\"\n", 323 " Returns a random number between min_value and max_value (inclusive).\n", 324 "\n", 325 " Args:\n", 326 " min_value: The minimum value of the range.\n", 327 " max_value: The maximum value of the range.\n", 328 "\n", 329 " Returns:\n", 330 " int: A random integer between min_value and max_value (inclusive).\n", 331 " \"\"\"\n", 332 " return random.randint(min_value, max_value)\n", 333 "\n", 334 "def get_current_time(time_zone: str, format: str) -> str:\n", 335 " \"\"\"\n", 336 " Returns the current time in the specified time zone and format.\n", 337 "\n", 338 " Args:\n", 339 " time_zone: The name of the time zone (e.g., 'America/New_York').\n", 340 " format: The desired output format (e.g., '%Y-%m-%d %H:%M:%S').\n", 341 "\n", 342 " Returns:\n", 343 " str: The current time formatted according to the specified format.\n", 344 "\n", 345 " Raises:\n", 346 " pytz.exceptions.UnknownTimeZoneError: If the specified time zone is invalid.\n", 347 " \"\"\"\n", 348 " tz = pytz.timezone(time_zone)\n", 349 " current_time = datetime.now(tz)\n", 350 " return current_time.strftime(format)\n", 351 "\n", 352 "def speak_to_user(assistant_message: str) -> str:\n", 353 " \"\"\"\n", 354 " Opens a text input widget for the user to provide feedback or confirm something.\n", 355 "\n", 356 " Args:\n", 357 " assistant_message: The message to display to the user.\n", 358 "\n", 359 " Returns:\n", 360 " str: The user's input as a string.\n", 361 " \"\"\"\n", 362 " print(assistant_message)\n", 363 " user_input = input(\"Please provide your feedback or confirmation: \")\n", 364 " time.sleep(5) # Wait for 5 seconds\n", 365 " return user_input\n", 366 
"\n", 367 "def get_user_location(accuracy: int) -> str:\n", 368 " \"\"\"\n", 369 " Returns the user's location based on the public IP address and accuracy level.\n", 370 " \n", 371 " Args:\n", 372 " accuracy: The level of detail for the location information.\n", 373 " 1 - Country only\n", 374 " 2 - City and country\n", 375 " 3 - City, region, and country\n", 376 " \n", 377 " Returns:\n", 378 " str: The location information based on the specified accuracy level or an error message.\n", 379 "\n", 380 " Raises:\n", 381 " KeyError: If the geolocation response does not contain a 'status' field.\n", 382 " Network errors (requests.RequestException) are caught and returned as an error string.\n", 383 " \"\"\"\n", 384 " try:\n", 385 " # Retrieve public IP address\n", 386 " ip_response = requests.get('https://api.ipify.org?format=json', timeout=10)\n", 387 " ip_response.raise_for_status()\n", 388 " ip_address = ip_response.json().get('ip')\n", 389 " \n", 390 " # Use public IP to get location data\n", 391 " location_url = f\"http://ip-api.com/json/{ip_address}\"\n", 392 " location_response = requests.get(location_url, timeout=10)\n", 393 " location_response.raise_for_status()\n", 394 " data = location_response.json()\n", 395 " \n", 396 " if data['status'] == 'fail':\n", 397 " return f\"Error in get_user_location: {data.get('message', 'Unknown error')}\"\n", 398 " \n", 399 " if accuracy == 1:\n", 400 " return data.get(\"country\", \"Unknown country\")\n", 401 " elif accuracy == 2:\n", 402 " return f\"{data.get('city', 'Unknown city')}, {data.get('country', 'Unknown country')}\"\n", 403 " elif accuracy == 3:\n", 404 " return f\"{data.get('city', 'Unknown city')}, {data.get('regionName', 'Unknown region')}, {data.get('country', 'Unknown country')}\"\n", 405 " else:\n", 406 " return \"Invalid accuracy level. 
Please specify 1 (Country), 2 (City and Country), or 3 (City, Region, and Country).\"\n", 407 " except requests.RequestException as e:\n", 408 " return f\"Error: {e}\"" 409 ] 410 }, 411 { 412 "cell_type": "code", 413 "execution_count": 5, 414 "metadata": {}, 415 "outputs": [], 416 "source": [ 417 "def test_functions():\n", 418 " print(\"Testing get_weather_forecast:\")\n", 419 " try:\n", 420 " weather = get_weather_forecast(\"London\")\n", 421 " print(f\"Weather in London: {weather}\")\n", 422 " except Exception as e:\n", 423 " print(f\"Error in get_weather_forecast: {str(e)}\")\n", 424 "\n", 425 " print(\"\\nTesting get_stock_price:\")\n", 426 " try:\n", 427 " price = get_stock_price(\"AAPL\")\n", 428 " print(f\"Current price of AAPL: {price}\")\n", 429 " except Exception as e:\n", 430 " print(f\"Error in get_stock_price: {str(e)}\")\n", 431 "\n", 432 " print(\"\\nTesting get_random_number:\")\n", 433 " try:\n", 434 " number = get_random_number(2, 42)\n", 435 " print(f\"Random number between 2 and 42: {number}\")\n", 436 " except Exception as e:\n", 437 " print(f\"Error in get_random_number: {str(e)}\")\n", 438 "\n", 439 " print(\"\\nTesting get_current_time:\")\n", 440 " try:\n", 441 " current_time = get_current_time(\"America/New_York\", \"%Y-%m-%d %H:%M:%S\")\n", 442 " print(f\"Current time in New York: {current_time}\")\n", 443 " except Exception as e:\n", 444 " print(f\"Error in get_current_time: {str(e)}\")\n", 445 "\n", 446 " print(\"\\nTesting get_user_location:\")\n", 447 " try:\n", 448 " location = get_user_location(2)\n", 449 " print(f\"Location for user's ip address: {location}\")\n", 450 " except Exception as e:\n", 451 " print(f\"Error in get_user_location: {str(e)}\")" 452 ] 453 }, 454 { 455 "cell_type": "code", 456 "execution_count": 7, 457 "metadata": {}, 458 "outputs": [ 459 { 460 "name": "stdout", 461 "output_type": "stream", 462 "text": [ 463 "Testing get_weather_forecast:\n" 464 ] 465 }, 466 { 467 "name": "stdout", 468 "output_type": "stream", 469 
"text": [ 470 "Weather in London: {'location': 'London', 'forecast': 'Sunny', 'temperature': '+67°F'}\n", 471 "\n", 472 "Testing get_stock_price:\n", 473 "Error in get_stock_price: Unknown format code 'f' for object of type 'str'\n", 474 "\n", 475 "Testing get_random_number:\n", 476 "Random number between 1 and 100: 25\n", 477 "\n", 478 "Testing get_current_time:\n", 479 "Current time in New York: 2024-08-13 07:01:56\n", 480 "\n", 481 "Testing get_user_location:\n", 482 "Location for user's ip address: Ashburn, United States\n" 483 ] 484 } 485 ], 486 "source": [ 487 "test_functions()" 488 ] 489 }, 490 { 491 "cell_type": "code", 492 "execution_count": 4, 493 "metadata": {}, 494 "outputs": [], 495 "source": [ 496 "tools = [\n", 497 " get_weather_forecast,\n", 498 " get_stock_price,\n", 499 " get_random_number,\n", 500 " get_current_time,\n", 501 " get_user_location,\n", 502 " speak_to_user\n", 503 "]" 504 ] 505 }, 506 { 507 "cell_type": "code", 508 "execution_count": 44, 509 "metadata": {}, 510 "outputs": [], 511 "source": [ 512 "messages = [\n", 513 " {\"role\": \"user\", \"content\": \"Get the current weather forecast for San Francisco and get stock price of Tesla\"}\n", 514 "]" 515 ] 516 }, 517 { 518 "cell_type": "code", 519 "execution_count": null, 520 "metadata": {}, 521 "outputs": [], 522 "source": [ 523 "messages = [\n", 524 " {\"role\": \"user\", \"content\": \"Get the user's location first. Once you have the correct user's location, get current weather forecast. Call the functions one at a time sequentially without commenting or asking for confirmation\"}\n", 525 "]" 526 ] 527 }, 528 { 529 "cell_type": "code", 530 "execution_count": 32, 531 "metadata": {}, 532 "outputs": [ 533 { 534 "name": "stdout", 535 "output_type": "stream", 536 "text": [ 537 "<|begin_of_text|>You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. 
Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> {\"type\": \"function\", \"function\": {\"name\": \"get_weather_forecast\", \"description\": \"get_weather_forecast(location: str) -> dict[str, str] - Retrieves a simple weather forecast for a given location.\n", 538 "\n", 539 " Args:\n", 540 " location(str): The name of the location to get the weather forecast for.\n", 541 " Returns:\n", 542 " dict[str, str]: A dictionary containing the location, forecast, and temperature.\n", 543 " If successful, the keys are 'location', 'forecast', and 'temperature'.\n", 544 " If unsuccessful, the key is 'error' with a corresponding error message.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"location\": {\"type\": \"string\", \"description\": \"The name of the location to get the weather forecast for.\"}}, \"required\": [\"location\"]}}\n", 545 "{\"type\": \"function\", \"function\": {\"name\": \"get_stock_price\", \"description\": \"get_stock_price(symbol: str) -> str - Retrieves the stock price for a given symbol.\n", 546 "\n", 547 " Args:\n", 548 " symbol(str): The stock symbol to look up.\n", 549 " Returns:\n", 550 " str: The current stock price or an error message.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"symbol\": {\"type\": \"string\", \"description\": \"The stock symbol to look up.\"}}, \"required\": [\"symbol\"]}}\n", 551 "{\"type\": \"function\", \"function\": {\"name\": \"get_random_number\", \"description\": \"get_random_number(min_value: int, max_value: int) -> int - Returns a random number between min_value and max_value (inclusive).\n", 552 "\n", 553 " Args:\n", 554 " min_value(int): The minimum value of the range. 
max_value(int): The maximum value of the range.\n", 555 " Returns:\n", 556 " int: A random integer between min_value and max_value (inclusive).\", \"parameters\": {\"type\": \"object\", \"properties\": {\"min_value\": {\"type\": \"integer\", \"description\": \"The minimum value of the range.\"}, \"max_value\": {\"type\": \"integer\", \"description\": \"The maximum value of the range.\"}}, \"required\": [\"min_value\", \"max_value\"]}}\n", 557 "{\"type\": \"function\", \"function\": {\"name\": \"get_current_time\", \"description\": \"get_current_time(time_zone: str, format: str) -> str - Returns the current time in the specified time zone and format.\n", 558 "\n", 559 " Args:\n", 560 " time_zone(str): The name of the time zone (e.g., 'America/New_York'). format(str): The desired output format (e.g., '%Y-%m-%d %H:%M:%S').\n", 561 " Returns:\n", 562 " str: The current time formatted according to the specified format.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"time_zone\": {\"type\": \"string\", \"description\": \"The name of the time zone (e.g., 'America/New_York').\"}, \"format\": {\"type\": \"string\", \"description\": \"The desired output format (e.g., '%Y-%m-%d %H:%M:%S').\"}}, \"required\": [\"time_zone\", \"format\"]}}\n", 563 "{\"type\": \"function\", \"function\": {\"name\": \"get_user_location\", \"description\": \"get_user_location(accuracy: int) -> str - Returns the user's location based on the public IP address and accuracy level.\n", 564 "\n", 565 " Args:\n", 566 " accuracy(int): The level of detail for the location information. 1 - Country only 2 - City and country 3 - City, region, and country\n", 567 " Returns:\n", 568 " str: The location information based on the specified accuracy level or an error message.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"accuracy\": {\"type\": \"integer\", \"description\": \"The level of detail for the location information. 
1 - Country only 2 - City and country 3 - City, region, and country\"}}, \"required\": [\"accuracy\"]}}\n", 569 "{\"type\": \"function\", \"function\": {\"name\": \"speak_to_user\", \"description\": \"speak_to_user(assistant_message: str) -> str - Opens a text input widget for the user to provide feedback or confirm something.\n", 570 "\n", 571 " Args:\n", 572 " assistant_message(str): The message to display to the user.\n", 573 " Returns:\n", 574 " str: The user's input as a string.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"assistant_message\": {\"type\": \"string\", \"description\": \"The message to display to the user.\"}}, \"required\": [\"assistant_message\"]}} </tools>Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}}, \"required\": [\"name\", \"arguments\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}\n", 575 "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n", 576 "<tool_call>\n", 577 "{\"name\": <function-name>, \"arguments\": <args-dict>}\n", 578 "</tool_call><|im_end|><|im_start|>user\n", 579 "Get the current weather forecast for San Francisco and get stock price of Tesla<|im_end|>\n", 580 "\n" 581 ] 582 } 583 ], 584 "source": [ 585 "prompt = tokenizer.apply_chat_template(\n", 586 " messages,\n", 587 " tools=tools,\n", 588 " tokenize=False\n", 589 ")\n", 590 "print(prompt)" 591 ] 592 }, 593 { 594 "cell_type": "code", 595 "execution_count": 33, 596 "metadata": {}, 597 "outputs": [], 598 "source": [ 599 "def run_hermes_tool_inference(messages, tokenizer, model, tools):\n", 600 " inputs = tokenizer.apply_chat_template(messages, chat_template=\"tool_use\", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors=\"pt\")\n", 601 " inputs = {k: v.to(model.device) for k, v in 
inputs.items()}\n", 602 " out = model.generate(**inputs, max_new_tokens=512)\n", 603 " return tokenizer.decode(out[0][len(inputs[\"input_ids\"][0]):])" 604 ] 605 }, 606 { 607 "cell_type": "code", 608 "execution_count": 34, 609 "metadata": {}, 610 "outputs": [ 611 { 612 "name": "stderr", 613 "output_type": "stream", 614 "text": [ 615 "Setting `pad_token_id` to `eos_token_id`:128003 for open-end generation.\n" 616 ] 617 }, 618 { 619 "name": "stdout", 620 "output_type": "stream", 621 "text": [ 622 "<tool_call>\n", 623 "{\"name\": \"get_weather_forecast\", \"arguments\": {\"location\": \"San Francisco\"}}\n", 624 "</tool_call>\n", 625 "<tool_call>\n", 626 "{\"name\": \"get_stock_price\", \"arguments\": {\"symbol\": \"TSLA\"}}\n", 627 "</tool_call>\n", 628 "<|im_end|>\n" 629 ] 630 } 631 ], 632 "source": [ 633 "response = run_hermes_tool_inference(messages, tokenizer, model, tools)\n", 634 "print(response)" 635 ] 636 }, 637 { 638 "cell_type": "code", 639 "execution_count": 40, 640 "metadata": {}, 641 "outputs": [], 642 "source": [ 643 "import re\n", 644 "import json\n", 645 "\n", 646 "def parse_tool_calls(response_text):\n", 647 " tool_calls = []\n", 648 " pattern = r'<tool_call>\\s*(.*?)\\s*</tool_call>'\n", 649 " matches = re.findall(pattern, response_text, re.DOTALL)\n", 650 " \n", 651 " for match in matches:\n", 652 " try:\n", 653 " tool_call = json.loads(match)\n", 654 " tool_calls.append({\n", 655 " \"type\": \"function\",\n", 656 " \"function\": {\n", 657 " \"name\": tool_call[\"name\"],\n", 658 " \"arguments\": json.dumps(tool_call[\"arguments\"])\n", 659 " }\n", 660 " })\n", 661 " except (json.JSONDecodeError, KeyError, TypeError):\n", 662 " print(f\"Failed to parse tool call: {match}\")\n", 663 " \n", 664 " return tool_calls" 665 ] 666 }, 667 { 668 "cell_type": "code", 669 "execution_count": 41, 670 "metadata": {}, 671 "outputs": [ 672 { 673 "name": "stdout", 674 "output_type": "stream", 675 "text": [ 676 "[{'type': 'function', 'function': {'name': 
'get_weather_forecast', 'arguments': 
'{\"location\": \"San Francisco\"}'}}, {'type': 'function', 'function': {'name': 'get_stock_price', 'arguments': '{\"symbol\": \"TSLA\"}'}}]\n" 677 ] 678 } 679 ], 680 "source": [ 681 "tool_calls = parse_tool_calls(response)\n", 682 "print(tool_calls)" 683 ] 684 }, 685 { 686 "cell_type": "code", 687 "execution_count": 45, 688 "metadata": {}, 689 "outputs": [], 690 "source": [ 691 "messages.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})" 692 ] 693 }, 694 { 695 "cell_type": "code", 696 "execution_count": 46, 697 "metadata": {}, 698 "outputs": [ 699 { 700 "name": "stdout", 701 "output_type": "stream", 702 "text": [ 703 "[\n", 704 " {\n", 705 " \"role\": \"user\",\n", 706 " \"content\": \"Get the current weather forecast for San Francisco and get stock price of Tesla\"\n", 707 " },\n", 708 " {\n", 709 " \"role\": \"assistant\",\n", 710 " \"tool_calls\": [\n", 711 " {\n", 712 " \"type\": \"function\",\n", 713 " \"function\": {\n", 714 " \"name\": \"get_weather_forecast\",\n", 715 " \"arguments\": \"{\\\"location\\\": \\\"San Francisco\\\"}\"\n", 716 " }\n", 717 " },\n", 718 " {\n", 719 " \"type\": \"function\",\n", 720 " \"function\": {\n", 721 " \"name\": \"get_stock_price\",\n", 722 " \"arguments\": \"{\\\"symbol\\\": \\\"TSLA\\\"}\"\n", 723 " }\n", 724 " }\n", 725 " ]\n", 726 " }\n", 727 "]\n" 728 ] 729 } 730 ], 731 "source": [ 732 "print(json.dumps(messages, indent=2))\n" 733 ] 734 }, 735 { 736 "cell_type": "code", 737 "execution_count": 47, 738 "metadata": {}, 739 "outputs": [ 740 { 741 "name": "stdout", 742 "output_type": "stream", 743 "text": [ 744 "<|begin_of_text|>You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. 
Here are the available tools: <tools> {\"type\": \"function\", \"function\": {\"name\": \"get_weather_forecast\", \"description\": \"get_weather_forecast(location: str) -> dict[str, str] - Retrieves a simple weather forecast for a given location.\n", 745 "\n", 746 " Args:\n", 747 " location(str): The name of the location to get the weather forecast for.\n", 748 " Returns:\n", 749 " dict[str, str]: A dictionary containing the location, forecast, and temperature.\n", 750 " If successful, the keys are 'location', 'forecast', and 'temperature'.\n", 751 " If unsuccessful, the key is 'error' with a corresponding error message.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"location\": {\"type\": \"string\", \"description\": \"The name of the location to get the weather forecast for.\"}}, \"required\": [\"location\"]}}\n", 752 "{\"type\": \"function\", \"function\": {\"name\": \"get_stock_price\", \"description\": \"get_stock_price(symbol: str) -> str - Retrieves the stock price for a given symbol.\n", 753 "\n", 754 " Args:\n", 755 " symbol(str): The stock symbol to look up.\n", 756 " Returns:\n", 757 " str: The current stock price or an error message.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"symbol\": {\"type\": \"string\", \"description\": \"The stock symbol to look up.\"}}, \"required\": [\"symbol\"]}}\n", 758 "{\"type\": \"function\", \"function\": {\"name\": \"get_random_number\", \"description\": \"get_random_number(min_value: int, max_value: int) -> int - Returns a random number between min_value and max_value (inclusive).\n", 759 "\n", 760 " Args:\n", 761 " min_value(int): The minimum value of the range. 
max_value(int): The maximum value of the range.\n", 762 " Returns:\n", 763 " int: A random integer between min_value and max_value (inclusive).\", \"parameters\": {\"type\": \"object\", \"properties\": {\"min_value\": {\"type\": \"integer\", \"description\": \"The minimum value of the range.\"}, \"max_value\": {\"type\": \"integer\", \"description\": \"The maximum value of the range.\"}}, \"required\": [\"min_value\", \"max_value\"]}}\n", 764 "{\"type\": \"function\", \"function\": {\"name\": \"get_current_time\", \"description\": \"get_current_time(time_zone: str, format: str) -> str - Returns the current time in the specified time zone and format.\n", 765 "\n", 766 " Args:\n", 767 " time_zone(str): The name of the time zone (e.g., 'America/New_York'). format(str): The desired output format (e.g., '%Y-%m-%d %H:%M:%S').\n", 768 " Returns:\n", 769 " str: The current time formatted according to the specified format.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"time_zone\": {\"type\": \"string\", \"description\": \"The name of the time zone (e.g., 'America/New_York').\"}, \"format\": {\"type\": \"string\", \"description\": \"The desired output format (e.g., '%Y-%m-%d %H:%M:%S').\"}}, \"required\": [\"time_zone\", \"format\"]}}\n", 770 "{\"type\": \"function\", \"function\": {\"name\": \"get_user_location\", \"description\": \"get_user_location(accuracy: int) -> str - Returns the user's location based on the public IP address and accuracy level.\n", 771 "\n", 772 " Args:\n", 773 " accuracy(int): The level of detail for the location information. 1 - Country only 2 - City and country 3 - City, region, and country\n", 774 " Returns:\n", 775 " str: The location information based on the specified accuracy level or an error message.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"accuracy\": {\"type\": \"integer\", \"description\": \"The level of detail for the location information. 
1 - Country only 2 - City and country 3 - City, region, and country\"}}, \"required\": [\"accuracy\"]}}\n", 776 "{\"type\": \"function\", \"function\": {\"name\": \"speak_to_user\", \"description\": \"speak_to_user(assistant_message: str) -> str - Opens a text input widget for the user to provide feedback or confirm something.\n", 777 "\n", 778 " Args:\n", 779 " assistant_message(str): The message to display to the user.\n", 780 " Returns:\n", 781 " str: The user's input as a string.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"assistant_message\": {\"type\": \"string\", \"description\": \"The message to display to the user.\"}}, \"required\": [\"assistant_message\"]}} </tools>Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}}, \"required\": [\"name\", \"arguments\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}\n", 782 "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n", 783 "<tool_call>\n", 784 "{\"name\": <function-name>, \"arguments\": <args-dict>}\n", 785 "</tool_call><|im_end|><|im_start|>user\n", 786 "Get the current weather forecast for San Francisco and get stock price of Tesla<|im_end|>\n", 787 "<|im_start|>assistant\n", 788 "<tool_call>\n", 789 "{\"name\": \"get_weather_forecast\"}, \"arguments\": \"{\\\"location\\\": \\\"San Francisco\\\"}\"\n", 790 "</tool_call>\n", 791 "<tool_call>\n", 792 "{\"name\": \"get_stock_price\"}, \"arguments\": \"{\\\"symbol\\\": \\\"TSLA\\\"}\"\n", 793 "</tool_call><|im_end|>\n", 794 "\n" 795 ] 796 } 797 ], 798 "source": [ 799 "\n", 800 "prompt = tokenizer.apply_chat_template(\n", 801 " messages,\n", 802 " tools=tools,\n", 803 " tokenize=False\n", 804 ")\n", 805 "print(prompt)" 806 ] 807 }, 808 { 809 "cell_type": "code", 810 "execution_count": 48, 811 
"metadata": {}, 812 "outputs": [ 813 { 814 "name": "stdout", 815 "output_type": "stream", 816 "text": [ 817 "Invoking tool call: get_weather_forecast\n" 818 ] 819 }, 820 { 821 "name": "stdout", 822 "output_type": "stream", 823 "text": [ 824 "Invoking tool call: get_stock_price\n" 825 ] 826 } 827 ], 828 "source": [ 829 " for tool_call in tool_calls:\n", 830 " function_call = tool_call[\"function\"]\n", 831 " name = function_call[\"name\"]\n", 832 " arguments = json.loads(function_call[\"arguments\"])\n", 833 " for function in tools:\n", 834 " if function.__name__ == name:\n", 835 " print(f\"Invoking tool call: {name}\")\n", 836 " result = function(**arguments)\n", 837 " result_content = {\n", 838 " \"name\": name,\n", 839 " \"content\": result\n", 840 " }\n", 841 " messages.append(\n", 842 " {\n", 843 " \"role\": \"tool\",\n", 844 " \"content\": json.dumps(result_content),\n", 845 " \"name\": name\n", 846 " }\n", 847 " )" 848 ] 849 }, 850 { 851 "cell_type": "code", 852 "execution_count": 49, 853 "metadata": {}, 854 "outputs": [ 855 { 856 "name": "stdout", 857 "output_type": "stream", 858 "text": [ 859 "[\n", 860 " {\n", 861 " \"role\": \"user\",\n", 862 " \"content\": \"Get the current weather forecast for San Francisco and get stock price of Tesla\"\n", 863 " },\n", 864 " {\n", 865 " \"role\": \"assistant\",\n", 866 " \"tool_calls\": [\n", 867 " {\n", 868 " \"type\": \"function\",\n", 869 " \"function\": {\n", 870 " \"name\": \"get_weather_forecast\",\n", 871 " \"arguments\": \"{\\\"location\\\": \\\"San Francisco\\\"}\"\n", 872 " }\n", 873 " },\n", 874 " {\n", 875 " \"type\": \"function\",\n", 876 " \"function\": {\n", 877 " \"name\": \"get_stock_price\",\n", 878 " \"arguments\": \"{\\\"symbol\\\": \\\"TSLA\\\"}\"\n", 879 " }\n", 880 " }\n", 881 " ]\n", 882 " },\n", 883 " {\n", 884 " \"role\": \"tool\",\n", 885 " \"content\": \"{\\\"name\\\": \\\"get_weather_forecast\\\", \\\"content\\\": {\\\"location\\\": \\\"San Francisco\\\", \\\"forecast\\\": \\\"Clear \\\", 
\\\"temperature\\\": \\\"+55\\\\u00b0F\\\"}}\",\n", 886 " \"name\": \"get_weather_forecast\"\n", 887 " },\n", 888 " {\n", 889 " \"role\": \"tool\",\n", 890 " \"content\": \"{\\\"name\\\": \\\"get_stock_price\\\", \\\"content\\\": \\\"$197.4900\\\"}\",\n", 891 " \"name\": \"get_stock_price\"\n", 892 " }\n", 893 "]\n" 894 ] 895 } 896 ], 897 "source": [ 898 "print(json.dumps(messages, indent=2))" 899 ] 900 }, 901 { 902 "cell_type": "code", 903 "execution_count": 50, 904 "metadata": {}, 905 "outputs": [ 906 { 907 "name": "stdout", 908 "output_type": "stream", 909 "text": [ 910 "<|begin_of_text|>You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> {\"type\": \"function\", \"function\": {\"name\": \"get_weather_forecast\", \"description\": \"get_weather_forecast(location: str) -> dict[str, str] - Retrieves a simple weather forecast for a given location.\n", 911 "\n", 912 " Args:\n", 913 " location(str): The name of the location to get the weather forecast for.\n", 914 " Returns:\n", 915 " dict[str, str]: A dictionary containing the location, forecast, and temperature.\n", 916 " If successful, the keys are 'location', 'forecast', and 'temperature'.\n", 917 " If unsuccessful, the key is 'error' with a corresponding error message.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"location\": {\"type\": \"string\", \"description\": \"The name of the location to get the weather forecast for.\"}}, \"required\": [\"location\"]}}\n", 918 "{\"type\": \"function\", \"function\": {\"name\": \"get_stock_price\", \"description\": \"get_stock_price(symbol: str) -> str - Retrieves the stock price for a given symbol.\n", 919 "\n", 920 " Args:\n", 921 " symbol(str): The stock symbol to look up.\n", 922 " Returns:\n", 923 " str: The current stock price 
or an error message.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"symbol\": {\"type\": \"string\", \"description\": \"The stock symbol to look up.\"}}, \"required\": [\"symbol\"]}}\n", 924 "{\"type\": \"function\", \"function\": {\"name\": \"get_random_number\", \"description\": \"get_random_number(min_value: int, max_value: int) -> int - Returns a random number between min_value and max_value (inclusive).\n", 925 "\n", 926 " Args:\n", 927 " min_value(int): The minimum value of the range. max_value(int): The maximum value of the range.\n", 928 " Returns:\n", 929 " int: A random integer between min_value and max_value (inclusive).\", \"parameters\": {\"type\": \"object\", \"properties\": {\"min_value\": {\"type\": \"integer\", \"description\": \"The minimum value of the range.\"}, \"max_value\": {\"type\": \"integer\", \"description\": \"The maximum value of the range.\"}}, \"required\": [\"min_value\", \"max_value\"]}}\n", 930 "{\"type\": \"function\", \"function\": {\"name\": \"get_current_time\", \"description\": \"get_current_time(time_zone: str, format: str) -> str - Returns the current time in the specified time zone and format.\n", 931 "\n", 932 " Args:\n", 933 " time_zone(str): The name of the time zone (e.g., 'America/New_York'). 
format(str): The desired output format (e.g., '%Y-%m-%d %H:%M:%S').\n", 934 " Returns:\n", 935 " str: The current time formatted according to the specified format.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"time_zone\": {\"type\": \"string\", \"description\": \"The name of the time zone (e.g., 'America/New_York').\"}, \"format\": {\"type\": \"string\", \"description\": \"The desired output format (e.g., '%Y-%m-%d %H:%M:%S').\"}}, \"required\": [\"time_zone\", \"format\"]}}\n", 936 "{\"type\": \"function\", \"function\": {\"name\": \"get_user_location\", \"description\": \"get_user_location(accuracy: int) -> str - Returns the user's location based on the public IP address and accuracy level.\n", 937 "\n", 938 " Args:\n", 939 " accuracy(int): The level of detail for the location information. 1 - Country only 2 - City and country 3 - City, region, and country\n", 940 " Returns:\n", 941 " str: The location information based on the specified accuracy level or an error message.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"accuracy\": {\"type\": \"integer\", \"description\": \"The level of detail for the location information. 
1 - Country only 2 - City and country 3 - City, region, and country\"}}, \"required\": [\"accuracy\"]}}\n", 942 "{\"type\": \"function\", \"function\": {\"name\": \"speak_to_user\", \"description\": \"speak_to_user(assistant_message: str) -> str - Opens a text input widget for the user to provide feedback or confirm something.\n", 943 "\n", 944 " Args:\n", 945 " assistant_message(str): The message to display to the user.\n", 946 " Returns:\n", 947 " str: The user's input as a string.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"assistant_message\": {\"type\": \"string\", \"description\": \"The message to display to the user.\"}}, \"required\": [\"assistant_message\"]}} </tools>Use the following pydantic model json schema for each tool call you will make: {\"properties\": {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"arguments\": {\"title\": \"Arguments\", \"type\": \"object\"}}, \"required\": [\"name\", \"arguments\"], \"title\": \"FunctionCall\", \"type\": \"object\"}}\n", 948 "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n", 949 "<tool_call>\n", 950 "{\"name\": <function-name>, \"arguments\": <args-dict>}\n", 951 "</tool_call><|im_end|><|im_start|>user\n", 952 "Get the current weather forecast for San Francisco and get stock price of Tesla<|im_end|>\n", 953 "<|im_start|>assistant\n", 954 "<tool_call>\n", 955 "{\"name\": \"get_weather_forecast\"}, \"arguments\": \"{\\\"location\\\": \\\"San Francisco\\\"}\"\n", 956 "</tool_call>\n", 957 "<tool_call>\n", 958 "{\"name\": \"get_stock_price\"}, \"arguments\": \"{\\\"symbol\\\": \\\"TSLA\\\"}\"\n", 959 "</tool_call><|im_end|>\n", 960 "<|im_start|>tool\n", 961 "<tool_response>\n", 962 "{\"name\": \"get_weather_forecast\", \"content\": {\"location\": \"San Francisco\", \"forecast\": \"Clear \", \"temperature\": \"+55\\u00b0F\"}}\n", 963 "</tool_response>\n", 964 "<tool_response>\n", 965 "{\"name\": 
\"get_stock_price\", \"content\": \"$197.4900\"}\n", 966 "</tool_response><|im_end|>\n" 967 ] 968 } 969 ], 970 "source": [ 971 "\n", 972 "prompt = tokenizer.apply_chat_template(\n", 973 " messages,\n", 974 " tools=tools,\n", 975 " tokenize=False\n", 976 ")\n", 977 "print(prompt)" 978 ] 979 }, 980 { 981 "cell_type": "code", 982 "execution_count": 51, 983 "metadata": {}, 984 "outputs": [], 985 "source": [ 986 "def recursive_tool_calling(messages, tokenizer, model, tools):\n", 987 " while True:\n", 988 " assistant_message = run_hermes_tool_inference(messages, tokenizer, model, tools)\n", 989 " \n", 990 " tool_calls = parse_tool_calls(assistant_message)\n", 991 " \n", 992 " if tool_calls:\n", 993 " messages.append({\"role\": \"assistant\", \"tool_calls\": tool_calls})\n", 994 " else:\n", 995 " messages.append({\"role\": \"assistant\", \"content\": assistant_message.strip(\"<|im_end|>\")})\n", 996 " \n", 997 " print(f\"Assistant Message: {assistant_message}\")\n", 998 "\n", 999 " if not tool_calls:\n", 1000 " break\n", 1001 "\n", 1002 " for tool_call in tool_calls:\n", 1003 " function_call = tool_call[\"function\"]\n", 1004 " name = function_call[\"name\"]\n", 1005 " arguments = json.loads(function_call[\"arguments\"])\n", 1006 " for function in tools:\n", 1007 " if function.__name__ == name:\n", 1008 " print(f\"Invoking tool call: {name}\")\n", 1009 " result = function(**arguments)\n", 1010 " result_content = {\n", 1011 " \"name\": name,\n", 1012 " \"content\": result\n", 1013 " }\n", 1014 " messages.append(\n", 1015 " {\n", 1016 " \"role\": \"tool\",\n", 1017 " \"content\": json.dumps(result_content),\n", 1018 " \"name\": name\n", 1019 " }\n", 1020 " )\n", 1021 " print(f\"Tool Call Result: {result_content}\")\n", 1022 " break\n", 1023 "\n", 1024 " return messages" 1025 ] 1026 }, 1027 { 1028 "cell_type": "code", 1029 "execution_count": 52, 1030 "metadata": {}, 1031 "outputs": [ 1032 { 1033 "name": "stderr", 1034 "output_type": "stream", 1035 "text": [ 1036 "Setting 
`pad_token_id` to `eos_token_id`:128003 for open-end generation.\n" 1037 ] 1038 }, 1039 { 1040 "name": "stdout", 1041 "output_type": "stream", 1042 "text": [ 1043 "Assistant Message: The current weather forecast for San Francisco is \"Clear\" and the temperature is +55°F. The stock price for Tesla (TSLA) is $197.49.<|im_end|>\n" 1044 ] 1045 } 1046 ], 1047 "source": [ 1048 "messages = recursive_tool_calling(messages, tokenizer, model, tools)\n" 1049 ] 1050 }, 1051 { 1052 "cell_type": "code", 1053 "execution_count": null, 1054 "metadata": {}, 1055 "outputs": [], 1056 "source": [] 1057 } 1058 ], 1059 "metadata": { 1060 "language_info": { 1061 "name": "python" 1062 } 1063 }, 1064 "nbformat": 4, 1065 "nbformat_minor": 2 1066 }