run.py
# Example client for an OpenAI-compatible chat endpoint that accepts
# "uc_function" tools. It runs two scenarios:
#   1. a request whose only tool is a UC function;
#   2. a request mixing the UC function with a user-defined ``multiply``
#      function, which this script executes locally before sending the
#      results back in a follow-up request.
import argparse
import json

import openai

# Address used when --base-url is not given (the value the script always used).
DEFAULT_BASE_URL = "http://localhost:7000/v1"


def parse_args():
    """Parse command-line options.

    Returns:
        argparse.Namespace with ``uc_function_name`` (required) and
        ``base_url`` (defaults to ``DEFAULT_BASE_URL``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--uc-function-name",
        type=str,
        required=True,
        help="Name of the UC function to use",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=DEFAULT_BASE_URL,
        help="Base URL of the OpenAI-compatible server",
    )
    return parser.parse_args()


def _multiply(x: int, y: int) -> int:
    """Local implementation of the user-defined ``multiply`` tool."""
    return x * y


def _tool_messages(tool_calls):
    """Execute every ``multiply`` tool call locally and build its ``tool`` reply.

    The chat-completions protocol expects one ``tool`` message for each
    ``tool_call_id`` in the preceding assistant message, so every call is
    answered rather than only the first.

    Args:
        tool_calls: The ``tool_calls`` list from an assistant message (may be None).

    Returns:
        A list of ``{"role": "tool", ...}`` message dicts, in call order.

    Raises:
        RuntimeError: If there are no tool calls, or a call names a function
            other than ``multiply`` (the only tool this script can run).
    """
    if not tool_calls:
        raise RuntimeError(
            "Expected the model to call the 'multiply' tool, but it made no tool calls"
        )
    replies = []
    for call in tool_calls:
        if call.function.name != "multiply":
            raise RuntimeError(f"Unexpected tool call: {call.function.name!r}")
        arguments = json.loads(call.function.arguments)
        replies.append(
            {
                "role": "tool",
                "tool_call_id": call.id,
                "content": str(_multiply(**arguments)),
            }
        )
    return replies


def main():
    """Run the UC-only scenario, then the UC + user-defined-function scenario."""
    args = parse_args()
    client = openai.OpenAI(base_url=args.base_url)

    print("----- UC function -----")
    uc_function = {
        "type": "uc_function",
        "uc_function": {
            "name": args.uc_function_name,
        },
    }

    resp = client.chat.completions.create(
        model="chat",
        messages=[
            {
                "role": "user",
                "content": "What is the result of 1 + 2?",
            }
        ],
        tools=[uc_function],
    )
    print(resp.choices[0].message.content)

    print("----- UC function + User-defined function -----")
    user_defined_function = {
        "type": "function",
        "function": {
            "description": "Multiply numbers",
            "name": "multiply",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "integer",
                        "description": "First number",
                    },
                    "y": {
                        "type": "integer",
                        "description": "Second number",
                    },
                },
                "required": ["x", "y"],
            },
        },
    }

    msg = {
        "role": "user",
        "content": (
            "What is the result of 1 + 2? What is the result of 3 + 4? What is the result of 5 * 6?"
        ),
    }
    resp = client.chat.completions.create(
        model="chat",
        messages=[msg],
        tools=[
            user_defined_function,
            uc_function,
        ],
    )

    message = resp.choices[0].message
    print(message.content)
    print(message.tool_calls)

    tool_replies = _tool_messages(message.tool_calls)
    resp = client.chat.completions.create(
        model="chat",
        messages=[
            msg,
            # A single assistant turn carries both the text and the tool calls.
            # Sending the text as its own assistant message breaks when it is
            # None, because an assistant message without tool_calls needs content.
            {
                "role": "assistant",
                "content": message.content,
                "tool_calls": message.tool_calls,
            },
            *tool_replies,
        ],
    )

    print(resp.choices[0].message.content)


if __name__ == "__main__":
    main()