# mlflow/types/chat.py
  1  from __future__ import annotations
  2  
  3  from typing import Annotated, Any, Literal
  4  from uuid import uuid4
  5  
  6  from pydantic import BaseModel, Field, model_serializer
  7  
  8  
class TextContentPart(BaseModel):
    """A message content part containing plain text."""

    type: Literal["text"]
    text: str
 12  
 13  
class ImageUrl(BaseModel):
    """
    Represents an image URL.

    Attributes:
        url: Either a URL of an image or base64 encoded data.
            https://platform.openai.com/docs/guides/vision?lang=curl#uploading-base64-encoded-images
        detail: The level of resolution for the image when the model receives it.
            For example, when set to "low", the model will see a image resized to
            512x512 pixels, which consumes fewer tokens. In OpenAI, this is optional
            and defaults to "auto".
            https://platform.openai.com/docs/guides/vision?lang=curl#low-or-high-fidelity-image-understanding
    """

    url: str
    # Optional here (None) so the field is omitted unless explicitly set by the caller.
    detail: Literal["auto", "low", "high"] | None = None
 30  
 31  
class ImageContentPart(BaseModel):
    """A message content part referencing an image via :py:class:`ImageUrl`."""

    type: Literal["image_url"]
    image_url: ImageUrl
 35  
 36  
class InputAudio(BaseModel):
    """Audio payload for an audio content part."""

    data: str  # base64 encoded data
    # Encoding of the audio payload; only wav and mp3 are accepted.
    format: Literal["wav", "mp3"]
 40  
 41  
class AudioContentPart(BaseModel):
    """A message content part carrying audio input via :py:class:`InputAudio`."""

    type: Literal["input_audio"]
    input_audio: InputAudio
 45  
 46  
# A heterogeneous list of content parts. The ``type`` field is the pydantic
# discriminator that selects which content-part model an element validates into.
ContentPartsList = list[
    Annotated[TextContentPart | ImageContentPart | AudioContentPart, Field(discriminator="type")]
]


# Message content is either a plain string or a list of content parts.
# ``union_mode="left_to_right"`` makes pydantic try ``str`` first rather than
# using smart-union matching.
ContentType = Annotated[str | ContentPartsList, Field(union_mode="left_to_right")]
 53  
 54  
 55  class Function(BaseModel):
 56      name: str | None = None
 57      arguments: str | None = None
 58  
 59      def to_tool_call(self, id=None) -> ToolCall:
 60          if id is None:
 61              id = str(uuid4())
 62          return ToolCall(id=id, type="function", function=self)
 63  
 64  
class ToolCall(BaseModel):
    """A tool call requested by the model, pairing an id with the function to invoke."""

    id: str
    type: str = Field(default="function")
    function: Function
 69  
 70  
class ChatMessage(BaseModel):
    """
    A chat request. ``content`` can be a string, or an array of content parts.

    A content part is one of the following:

    - :py:class:`TextContentPart <mlflow.types.chat.TextContentPart>`
    - :py:class:`ImageContentPart <mlflow.types.chat.ImageContentPart>`
    - :py:class:`AudioContentPart <mlflow.types.chat.AudioContentPart>`
    """

    role: str
    content: ContentType | None = None
    # NB: In the actual OpenAI chat completion API spec, these fields only
    #   present in either the request or response message (tool_call_id is only in
    #   the request, while the other two are only in the response).
    #   Strictly speaking, we should separate the request and response message types
    #   to match OpenAI's API spec. However, we don't want to do that because the
    #   request and response message types are not distinguished in many parts of the
    #   codebase, and also we don't want to ask users to use two different classes.
    #   Therefore, we include all fields in this class, while marking them as optional.
    # TODO: Define sub classes for different types of messages (request/response, and
    #   system/user/assistant/tool, etc), and create a factory function to allow users
    #   to create them without worrying about the details.
    tool_calls: list[ToolCall] | None = None
    refusal: str | None = None
    tool_call_id: str | None = None
 98  
 99  
100  AllowedType = Literal["string", "number", "integer", "object", "array", "boolean", "null"]
101  
102  
class ParamType(BaseModel):
    """JSON Schema ``type`` keyword: a single type name, a list of them, or unset."""

    type: AllowedType | list[AllowedType] | None = None
105  
106  
class ParamProperty(ParamType):
    """
    OpenAI uses JSON Schema (https://json-schema.org/) for function parameters.
    See OpenAI function calling reference:
    https://platform.openai.com/docs/guides/function-calling?&api-mode=responses#defining-functions

    JSON Schema enum supports any JSON type (str, int, float, bool, null, arrays, objects),
    but we restrict to basic scalar types for practical use cases and API safety.
    """

    description: str | None = None
    # Allowed scalar values when the parameter is an enumeration.
    enum: list[str | int | float | bool] | None = None
    # Element schema when ``type`` is "array".
    items: ParamType | None = None
120  
121  
class FunctionParams(BaseModel):
    """JSON Schema object describing a function's parameters."""

    properties: dict[str, ParamProperty]
    # The top-level parameter schema is always an "object" in this spec.
    type: Literal["object"] = "object"
    required: list[str] | None = None
    # camelCase mirrors the JSON Schema keyword name verbatim.
    additionalProperties: bool | None = None
127  
128  
class FunctionToolDefinition(BaseModel):
    """Definition of a function that the model may call as a tool."""

    name: str
    description: str | None = None
    parameters: FunctionParams | None = None
    # NOTE(review): presumably OpenAI's strict schema-adherence flag — confirm
    strict: bool | None = None
134  
135  
class ChatTool(BaseModel):
    """
    A tool definition passed to the chat completion API.

    Ref: https://platform.openai.com/docs/guides/function-calling
    """

    # Only "function" tools are representable here.
    type: Literal["function"]
    function: FunctionToolDefinition | None = None
145  
146  
class ResponseFormat(BaseModel):
    """
    Response format configuration for structured outputs.

    Supported formats: {"type": "json_schema", "json_schema": {...}}.

    The schema should follow JSON Schema specification.
    """

    type: Literal["text", "json_object", "json_schema"]
    # Only meaningful when ``type`` is "json_schema".
    json_schema: dict[str, Any] | None = None
158  
159  
class ToolChoiceFunction(BaseModel):
    """Specifies a tool the model should use."""

    # Name of the function the model is directed to call.
    name: str
164  
165  
class ToolChoice(BaseModel):
    """
    Specifies a particular tool to use.

    OpenAI format: {"type": "function", "function": {"name": "my_function"}}
    """

    type: Literal["function"]
    function: ToolChoiceFunction
175  
176  
class BaseRequestPayload(BaseModel):
    """Common parameters used for chat completions and completion endpoints."""

    # Number of choices to generate; must be at least 1.
    n: int = Field(1, ge=1)
    # Stop sequences; when provided, must be non-empty.
    stop: list[str] | None = Field(None, min_length=1)
    max_tokens: int | None = Field(None, ge=1)
    max_completion_tokens: int | None = Field(None, ge=1)
    stream: bool | None = None
    stream_options: dict[str, Any] | None = None
    model: str | None = None
    response_format: ResponseFormat | None = None
    # Ranges below mirror OpenAI's documented bounds for these sampling knobs.
    temperature: float | None = Field(None, ge=0, le=2)
    top_p: float | None = Field(None, ge=0, le=1)
    presence_penalty: float | None = Field(None, ge=-2, le=2)
    frequency_penalty: float | None = Field(None, ge=-2, le=2)
    top_k: int | None = Field(None, ge=1)
193  
194  
195  # NB: For interface constructs that rely on other BaseModel implementations, in
196  # pydantic 1 the **order** in which classes are defined in this module is absolutely
197  # critical to prevent ForwardRef errors. Pydantic 2 does not have this limitation.
198  # To maintain compatibility with Pydantic 1, ensure that all classes that are defined in
199  # this file have dependencies defined higher than the line of usage.
200  
201  
class ChatChoice(BaseModel):
    """A single completion choice in a chat response."""

    index: int
    message: ChatMessage
    finish_reason: str | None = None
206  
207  
class PromptTokensDetails(BaseModel):
    """Breakdown of prompt token usage; extra provider-specific keys are preserved."""

    model_config = {"extra": "allow"}

    cached_tokens: int | None = None
212  
213  
class ChatUsage(BaseModel):
    """Token usage accounting for a chat completion; extra keys are preserved."""

    model_config = {"extra": "allow"}

    prompt_tokens: int | None = None
    completion_tokens: int | None = None
    total_tokens: int | None = None
    prompt_tokens_details: PromptTokensDetails | None = None

    @model_serializer(mode="wrap")
    def _serialize(self, handler):
        # Omit ``prompt_tokens_details`` from serialized output when it is unset,
        # rather than emitting an explicit null.
        serialized = handler(self)
        if "prompt_tokens_details" in serialized and serialized["prompt_tokens_details"] is None:
            del serialized["prompt_tokens_details"]
        return serialized
228  
229  
class ToolCallDelta(BaseModel):
    """An incremental tool-call fragment within a streaming chat chunk."""

    # Position of the tool call this fragment belongs to.
    index: int
    id: str | None = None
    type: str | None = None
    function: Function
235  
236  
class ChatChoiceDelta(BaseModel):
    """The incremental message payload of a streaming chat choice."""

    role: str | None = None
    content: str | None = None
    tool_calls: list[ToolCallDelta] | None = None
241  
242  
class ChatChunkChoice(BaseModel):
    """A single choice within a streaming chat completion chunk."""

    index: int
    finish_reason: str | None = None
    delta: ChatChoiceDelta
247  
248  
class ChatCompletionChunk(BaseModel):
    """A chunk of a chat completion stream response."""

    id: str | None = None
    object: str = "chat.completion.chunk"
    # Unix timestamp (seconds) of chunk creation, per the OpenAI schema.
    created: int
    model: str
    choices: list[ChatChunkChoice]
    usage: ChatUsage | None = None
258  
259  
class ChatCompletionRequest(BaseRequestPayload):
    """
    A request to the chat completion API.

    Must be compatible with OpenAI's Chat Completion API.
    https://platform.openai.com/docs/api-reference/chat
    """

    # At least one message is required.
    messages: list[ChatMessage] = Field(..., min_length=1)
    # When provided, the tool list must be non-empty.
    tools: list[ChatTool] | None = Field(None, min_length=1)
    # Either a string mode or an explicit ToolChoice directive.
    tool_choice: Literal["none", "auto", "required"] | ToolChoice | None = None
271  
272  
class ChatCompletionResponse(BaseModel):
    """
    A response from the chat completion API.

    Must be compatible with OpenAI's Chat Completion API.
    https://platform.openai.com/docs/api-reference/chat
    """

    id: str | None = None
    object: str = "chat.completion"
    # Unix timestamp (seconds) of response creation, per the OpenAI schema.
    created: int
    model: str
    choices: list[ChatChoice]
    usage: ChatUsage