# chat.py
1 from __future__ import annotations 2 3 from typing import Annotated, Any, Literal 4 from uuid import uuid4 5 6 from pydantic import BaseModel, Field, model_serializer 7 8 9 class TextContentPart(BaseModel): 10 type: Literal["text"] 11 text: str 12 13 14 class ImageUrl(BaseModel): 15 """ 16 Represents an image URL. 17 18 Attributes: 19 url: Either a URL of an image or base64 encoded data. 20 https://platform.openai.com/docs/guides/vision?lang=curl#uploading-base64-encoded-images 21 detail: The level of resolution for the image when the model receives it. 22 For example, when set to "low", the model will see a image resized to 23 512x512 pixels, which consumes fewer tokens. In OpenAI, this is optional 24 and defaults to "auto". 25 https://platform.openai.com/docs/guides/vision?lang=curl#low-or-high-fidelity-image-understanding 26 """ 27 28 url: str 29 detail: Literal["auto", "low", "high"] | None = None 30 31 32 class ImageContentPart(BaseModel): 33 type: Literal["image_url"] 34 image_url: ImageUrl 35 36 37 class InputAudio(BaseModel): 38 data: str # base64 encoded data 39 format: Literal["wav", "mp3"] 40 41 42 class AudioContentPart(BaseModel): 43 type: Literal["input_audio"] 44 input_audio: InputAudio 45 46 47 ContentPartsList = list[ 48 Annotated[TextContentPart | ImageContentPart | AudioContentPart, Field(discriminator="type")] 49 ] 50 51 52 ContentType = Annotated[str | ContentPartsList, Field(union_mode="left_to_right")] 53 54 55 class Function(BaseModel): 56 name: str | None = None 57 arguments: str | None = None 58 59 def to_tool_call(self, id=None) -> ToolCall: 60 if id is None: 61 id = str(uuid4()) 62 return ToolCall(id=id, type="function", function=self) 63 64 65 class ToolCall(BaseModel): 66 id: str 67 type: str = Field(default="function") 68 function: Function 69 70 71 class ChatMessage(BaseModel): 72 """ 73 A chat request. ``content`` can be a string, or an array of content parts. 
74 75 A content part is one of the following: 76 77 - :py:class:`TextContentPart <mlflow.types.chat.TextContentPart>` 78 - :py:class:`ImageContentPart <mlflow.types.chat.ImageContentPart>` 79 - :py:class:`AudioContentPart <mlflow.types.chat.AudioContentPart>` 80 """ 81 82 role: str 83 content: ContentType | None = None 84 # NB: In the actual OpenAI chat completion API spec, these fields only 85 # present in either the request or response message (tool_call_id is only in 86 # the request, while the other two are only in the response). 87 # Strictly speaking, we should separate the request and response message types 88 # to match OpenAI's API spec. However, we don't want to do that because we the 89 # request and response message types are not distinguished in many parts of the 90 # codebase, and also we don't want to ask users to use two different classes. 91 # Therefore, we include all fields in this class, while marking them as optional. 92 # TODO: Define a sub classes for different type of messages (request/response, and 93 # system/user/assistant/tool, etc), and create a factory function to allow users 94 # to create them without worrying about the details. 95 tool_calls: list[ToolCall] | None = None 96 refusal: str | None = None 97 tool_call_id: str | None = None 98 99 100 AllowedType = Literal["string", "number", "integer", "object", "array", "boolean", "null"] 101 102 103 class ParamType(BaseModel): 104 type: AllowedType | list[AllowedType] | None = None 105 106 107 class ParamProperty(ParamType): 108 """ 109 OpenAI uses JSON Schema (https://json-schema.org/) for function parameters. 110 See OpenAI function calling reference: 111 https://platform.openai.com/docs/guides/function-calling?&api-mode=responses#defining-functions 112 113 JSON Schema enum supports any JSON type (str, int, float, bool, null, arrays, objects), 114 but we restrict to basic scalar types for practical use cases and API safety. 
115 """ 116 117 description: str | None = None 118 enum: list[str | int | float | bool] | None = None 119 items: ParamType | None = None 120 121 122 class FunctionParams(BaseModel): 123 properties: dict[str, ParamProperty] 124 type: Literal["object"] = "object" 125 required: list[str] | None = None 126 additionalProperties: bool | None = None 127 128 129 class FunctionToolDefinition(BaseModel): 130 name: str 131 description: str | None = None 132 parameters: FunctionParams | None = None 133 strict: bool | None = None 134 135 136 class ChatTool(BaseModel): 137 """ 138 A tool definition passed to the chat completion API. 139 140 Ref: https://platform.openai.com/docs/guides/function-calling 141 """ 142 143 type: Literal["function"] 144 function: FunctionToolDefinition | None = None 145 146 147 class ResponseFormat(BaseModel): 148 """ 149 Response format configuration for structured outputs. 150 151 Supported formats: {"type": "json_schema", "json_schema": {...}}. 152 153 The schema should follow JSON Schema specification. 154 """ 155 156 type: Literal["text", "json_object", "json_schema"] 157 json_schema: dict[str, Any] | None = None 158 159 160 class ToolChoiceFunction(BaseModel): 161 """Specifies a tool the model should use.""" 162 163 name: str 164 165 166 class ToolChoice(BaseModel): 167 """ 168 Specifies a particular tool to use. 
169 170 OpenAI format: {"type": "function", "function": {"name": "my_function"}} 171 """ 172 173 type: Literal["function"] 174 function: ToolChoiceFunction 175 176 177 class BaseRequestPayload(BaseModel): 178 """Common parameters used for chat completions and completion endpoints.""" 179 180 n: int = Field(1, ge=1) 181 stop: list[str] | None = Field(None, min_length=1) 182 max_tokens: int | None = Field(None, ge=1) 183 max_completion_tokens: int | None = Field(None, ge=1) 184 stream: bool | None = None 185 stream_options: dict[str, Any] | None = None 186 model: str | None = None 187 response_format: ResponseFormat | None = None 188 temperature: float | None = Field(None, ge=0, le=2) 189 top_p: float | None = Field(None, ge=0, le=1) 190 presence_penalty: float | None = Field(None, ge=-2, le=2) 191 frequency_penalty: float | None = Field(None, ge=-2, le=2) 192 top_k: int | None = Field(None, ge=1) 193 194 195 # NB: For interface constructs that rely on other BaseModel implementations, in 196 # pydantic 1 the **order** in which classes are defined in this module is absolutely 197 # critical to prevent ForwardRef errors. Pydantic 2 does not have this limitation. 198 # To maintain compatibility with Pydantic 1, ensure that all classes that are defined in 199 # this file have dependencies defined higher than the line of usage. 
class ChatChoice(BaseModel):
    """A single choice in a (non-streaming) chat completion response."""

    index: int
    message: ChatMessage
    finish_reason: str | None = None


class PromptTokensDetails(BaseModel):
    """Breakdown of prompt token usage. Extra provider-specific fields are allowed."""

    model_config = {"extra": "allow"}

    cached_tokens: int | None = None


class ChatUsage(BaseModel):
    """Token usage statistics for a completion. Extra provider-specific fields are allowed."""

    model_config = {"extra": "allow"}

    prompt_tokens: int | None = None
    completion_tokens: int | None = None
    total_tokens: int | None = None
    prompt_tokens_details: PromptTokensDetails | None = None

    @model_serializer(mode="wrap")
    def _serialize(self, handler):
        # When prompt_tokens_details is None, remove it from the serialized
        # output so the field is omitted entirely rather than emitted as null.
        data = handler(self)
        if data.get("prompt_tokens_details") is None:
            data.pop("prompt_tokens_details", None)
        return data


class ToolCallDelta(BaseModel):
    """An incremental tool-call fragment within a streaming chunk."""

    index: int
    id: str | None = None
    type: str | None = None
    function: Function


class ChatChoiceDelta(BaseModel):
    """The incremental message payload of one streaming choice."""

    role: str | None = None
    content: str | None = None
    tool_calls: list[ToolCallDelta] | None = None


class ChatChunkChoice(BaseModel):
    """A single choice within a streaming chunk."""

    index: int
    finish_reason: str | None = None
    delta: ChatChoiceDelta


class ChatCompletionChunk(BaseModel):
    """A chunk of a chat completion stream response."""

    id: str | None = None
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: list[ChatChunkChoice]
    usage: ChatUsage | None = None


class ChatCompletionRequest(BaseRequestPayload):
    """
    A request to the chat completion API.

    Must be compatible with OpenAI's Chat Completion API.
    https://platform.openai.com/docs/api-reference/chat
    """

    messages: list[ChatMessage] = Field(..., min_length=1)
    tools: list[ChatTool] | None = Field(None, min_length=1)
    tool_choice: Literal["none", "auto", "required"] | ToolChoice | None = None


class ChatCompletionResponse(BaseModel):
    """
    A response from the chat completion API.

    Must be compatible with OpenAI's Chat Completion API.
    https://platform.openai.com/docs/api-reference/chat
    """

    id: str | None = None
    object: str = "chat.completion"
    created: int
    model: str
    choices: list[ChatChoice]
    usage: ChatUsage