tests/mistral/test_mistral_autolog.py
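"""Tests for MLflow's Mistral autologging: sync and async chat completions,
tool calling, token-usage capture, and cost calculation."""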
from unittest.mock import patch

import httpx
import pytest

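# The mistralai SDK moved its public classes between major versions; import from
# whichever module layout is available and record the matching patch targets.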
try:
    from mistralai.client import Mistral  # mistralai >= 2.0
    from mistralai.client.models import (
        AssistantMessage,
        ChatCompletionChoice,
        ChatCompletionResponse,
        FunctionCall,
        ToolCall,
        UsageInfo,
    )

    CHAT_DO_REQUEST_PATH = "mistralai.client.chat.Chat.do_request"
    CHAT_DO_REQUEST_ASYNC_PATH = "mistralai.client.chat.Chat.do_request_async"
except ImportError:
    from mistralai import Mistral  # mistralai < 2.0
    from mistralai.models import (
        AssistantMessage,
        ChatCompletionChoice,
        ChatCompletionResponse,
        FunctionCall,
        ToolCall,
        UsageInfo,
    )

    CHAT_DO_REQUEST_PATH = "mistralai.chat.Chat.do_request"
    CHAT_DO_REQUEST_ASYNC_PATH = "mistralai.chat.Chat.do_request_async"
from pydantic import BaseModel

import mlflow.mistral
from mlflow.entities.span import SpanType
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.tracing.helper import get_traces

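# Canned request/response payloads; the tests patch the SDK's HTTP layer so no
# real Mistral API call is ever made.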
DUMMY_CHAT_COMPLETION_REQUEST = {
    "model": "test_model",
    "max_tokens": 1024,
    "messages": [{"role": "user", "content": "test message"}],
}

DUMMY_CHAT_COMPLETION_RESPONSE = ChatCompletionResponse(
    id="test_id",
    object="chat.completion",
    model="test_model",
    usage=UsageInfo(prompt_tokens=10, completion_tokens=18, total_tokens=28),
    created=1736200000,
    choices=[
        ChatCompletionChoice(
            index=0,
            message=AssistantMessage(
                role="assistant",
                content="test answer",
                prefix=False,
                tool_calls=None,
            ),
            finish_reason="stop",
        )
    ],
)

# Ref: https://docs.mistral.ai/capabilities/function_calling/
DUMMY_CHAT_COMPLETION_WITH_TOOLS_REQUEST = {
    "model": "test_model",
    "max_tokens": 1024,
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_unit",
                "description": "Get the temperature unit commonly used in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g., San Francisco, CA",
                        },
                    },
                    "required": ["location"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g., San Francisco, CA",
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": 'The unit of temperature, "celsius" or "fahrenheit"',
                        },
                    },
                    "required": ["location", "unit"],
                },
            },
        },
    ],
    "messages": [
        {"role": "user", "content": "What's the weather like in San Francisco?"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "function": {"name": "get_unit", "arguments": '{"location": "San Francisco"}'},
                    "id": "tool_123",
                    "type": "function",
                }
            ],
            "prefix": False,
        },
        {"role": "tool", "name": "get_unit", "content": "celsius", "tool_call_id": "tool_123"},
    ],
}

DUMMY_CHAT_COMPLETION_WITH_TOOLS_RESPONSE = ChatCompletionResponse(
    id="test_id",
    object="chat.completion",
    model="test_model",
    usage=UsageInfo(prompt_tokens=11, completion_tokens=19, total_tokens=30),
    created=1736300000,
    choices=[
        ChatCompletionChoice(
            index=0,
            message=AssistantMessage(
                role="assistant",
                content="",
                prefix=False,
                tool_calls=[
                    ToolCall(
                        function=FunctionCall(
                            name="get_weather",
                            arguments='{"location": "San Francisco", "unit": "celsius"}',
                        ),
                        id="tool_456",
                        type="function",
                    ),
                ],
            ),
            finish_reason="tool_calls",
        )
    ],
)


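# Serialize a pydantic model into the raw httpx.Response that the patched
# do_request / do_request_async methods return; the SDK then parses it as if
# it had come from the API.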
def _make_httpx_response(response: BaseModel, status_code: int = 200) -> httpx.Response:
    return httpx.Response(
        status_code=status_code,
        headers={"Content-Type": "application/json"},
        text=response.model_dump_json(),
    )


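# mock_litellm_cost is assumed to be a conftest fixture that stubs LiteLLM's
# per-token prices (1.0 per input token, 2.0 per output token here) so the
# cost assertions below are deterministic.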
def test_chat_complete_autolog(mock_litellm_cost):
    with patch(
        CHAT_DO_REQUEST_PATH,
        return_value=_make_httpx_response(DUMMY_CHAT_COMPLETION_RESPONSE),
    ):
        mlflow.mistral.autolog()
        client = Mistral(api_key="test_key")
        client.chat.complete(**DUMMYY_CHAT_COMPLETION_REQUEST) if False else client.chat.complete(**DUMMY_CHAT_COMPLETION_REQUEST)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert len(traces[0].data.spans) == 1
    span = traces[0].data.spans[0]
    assert span.name == "Chat.complete"
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.inputs == DUMMY_CHAT_COMPLETION_REQUEST
    # Autologging may add extra fields to the usage dict; keep only the
    # prompt/completion/total token counts so the comparison below holds.
    span.outputs["usage"] = {
        key: span.outputs["usage"][key]
        for key in ["prompt_tokens", "completion_tokens", "total_tokens"]
    }
    assert span.outputs == DUMMY_CHAT_COMPLETION_RESPONSE.model_dump()
    assert span.model_name == "test_model"
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "mistral"
    assert traces[0].info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 10,
        TokenUsageKey.OUTPUT_TOKENS: 18,
        TokenUsageKey.TOTAL_TOKENS: 28,
    }
    if not IS_TRACING_SDK_ONLY:
        # Verify cost is calculated (10 input tokens * 1.0 + 18 output tokens * 2.0)
        assert span.llm_cost == {
            "input_cost": 10.0,
            "output_cost": 36.0,
            "total_cost": 46.0,
        }

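    # Re-invoke with autologging disabled; the call should go through untraced.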
    with patch(
        CHAT_DO_REQUEST_PATH,
        return_value=_make_httpx_response(DUMMY_CHAT_COMPLETION_RESPONSE),
    ):
        mlflow.mistral.autolog(disable=True)
        client = Mistral(api_key="test_key")
        client.chat.complete(**DUMMY_CHAT_COMPLETION_REQUEST)

    # No new trace should be created
    traces = get_traces()
    assert len(traces) == 1


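# Tool-calling round trip: the request carries tool schemas and a prior tool
# result, and the mocked response asks for a follow-up tool call.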
def test_chat_complete_autolog_tool_calling():
    with patch(
        CHAT_DO_REQUEST_PATH,
        return_value=_make_httpx_response(DUMMY_CHAT_COMPLETION_WITH_TOOLS_RESPONSE),
    ):
        mlflow.mistral.autolog()
        client = Mistral(api_key="test_key")
        client.chat.complete(**DUMMY_CHAT_COMPLETION_WITH_TOOLS_REQUEST)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert len(traces[0].data.spans) == 1
    span = traces[0].data.spans[0]
    assert span.name == "Chat.complete"
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.inputs == DUMMY_CHAT_COMPLETION_WITH_TOOLS_REQUEST
    assert span.outputs == DUMMY_CHAT_COMPLETION_WITH_TOOLS_RESPONSE.model_dump()
    assert span.model_name == "test_model"

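    # The CHAT_TOOLS span attribute should echo the tool definitions from the request.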
    assert span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == [
        {
            "type": "function",
            "function": {
                "name": "get_unit",
                "description": "Get the temperature unit commonly used in a given location",
                "parameters": {
                    "properties": {
                        "location": {
                            "description": "The city and state, e.g., San Francisco, CA",
                            "type": "string",
                        },
                    },
                    "required": ["location"],
                    "type": "object",
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "properties": {
                        "location": {
                            "description": "The city and state, e.g., San Francisco, CA",
                            "type": "string",
                        },
                        "unit": {
                            "description": 'The unit of temperature, "celsius" or "fahrenheit"',
                            "enum": ["celsius", "fahrenheit"],
                            "type": "string",
                        },
                    },
                    "required": ["location", "unit"],
                    "type": "object",
                },
            },
        },
    ]
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "mistral"
    assert traces[0].info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 11,
        TokenUsageKey.OUTPUT_TOKENS: 19,
        TokenUsageKey.TOTAL_TOKENS: 30,
    }


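# The async client path should produce an equivalent trace via do_request_async;
# since Python 3.8, patch() wraps async targets in an AsyncMock automatically,
# so the return_value is awaited correctly.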
@pytest.mark.asyncio
async def test_chat_complete_async_autolog():
    with patch(
        CHAT_DO_REQUEST_ASYNC_PATH,
        return_value=_make_httpx_response(DUMMY_CHAT_COMPLETION_RESPONSE),
    ):
        mlflow.mistral.autolog()
        client = Mistral(api_key="test_key")
        await client.chat.complete_async(**DUMMY_CHAT_COMPLETION_REQUEST)

    traces = get_traces()
    assert len(traces) == 1
    span = traces[0].data.spans[0]
    assert span.name == "Chat.complete_async"
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.inputs == DUMMY_CHAT_COMPLETION_REQUEST
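    # As above, drop any extra usage fields added by autologging before comparing.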
    span.outputs["usage"] = {
        key: span.outputs["usage"][key]
        for key in ["prompt_tokens", "completion_tokens", "total_tokens"]
    }
    assert span.outputs == DUMMY_CHAT_COMPLETION_RESPONSE.model_dump()
    assert span.model_name == "test_model"
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "mistral"
    assert traces[0].info.token_usage == {
        TokenUsageKey.INPUT_TOKENS: 10,
        TokenUsageKey.OUTPUT_TOKENS: 18,
        TokenUsageKey.TOTAL_TOKENS: 28,
    }