test_mistral_autolog.py
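"""Tests for MLflow's Mistral autologging integration.

The Mistral SDK's HTTP layer (``Chat.do_request`` / ``Chat.do_request_async``)
is mocked with canned ``httpx`` responses, and the tests assert that
``mlflow.mistral.autolog()`` records a single CHAT_MODEL span with the
expected inputs, outputs, tool definitions, token usage, and cost.
"""
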
from unittest.mock import patch

import httpx
import pytest

try:
    from mistralai.client import Mistral  # mistralai >= 2.0
    from mistralai.client.models import (
        AssistantMessage,
        ChatCompletionChoice,
        ChatCompletionResponse,
        FunctionCall,
        ToolCall,
        UsageInfo,
    )

    CHAT_DO_REQUEST_PATH = "mistralai.client.chat.Chat.do_request"
    CHAT_DO_REQUEST_ASYNC_PATH = "mistralai.client.chat.Chat.do_request_async"
except ImportError:
    from mistralai import Mistral  # mistralai < 2.0
    from mistralai.models import (
        AssistantMessage,
        ChatCompletionChoice,
        ChatCompletionResponse,
        FunctionCall,
        ToolCall,
        UsageInfo,
    )

    CHAT_DO_REQUEST_PATH = "mistralai.chat.Chat.do_request"
    CHAT_DO_REQUEST_ASYNC_PATH = "mistralai.chat.Chat.do_request_async"
from pydantic import BaseModel

import mlflow.mistral
from mlflow.entities.span import SpanType
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.tracing.helper import get_traces

DUMMY_CHAT_COMPLETION_REQUEST = {
    "model": "test_model",
    "max_tokens": 1024,
    "messages": [{"role": "user", "content": "test message"}],
}

DUMMY_CHAT_COMPLETION_RESPONSE = ChatCompletionResponse(
    id="test_id",
    object="chat.completion",
    model="test_model",
    usage=UsageInfo(prompt_tokens=10, completion_tokens=18, total_tokens=28),
    created=1736200000,
    choices=[
        ChatCompletionChoice(
            index=0,
            message=AssistantMessage(
                role="assistant",
                content="test answer",
                prefix=False,
                tool_calls=None,
            ),
            finish_reason="stop",
        )
    ],
)

# Ref: https://docs.mistral.ai/capabilities/function_calling/
DUMMY_CHAT_COMPLETION_WITH_TOOLS_REQUEST = {
    "model": "test_model",
    "max_tokens": 1024,
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_unit",
                "description": "Get the temperature unit commonly used in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g., San Francisco, CA",
                        },
                    },
                    "required": ["location"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g., San Francisco, CA",
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": 'The unit of temperature, "celsius" or "fahrenheit"',
                        },
                    },
                    "required": ["location", "unit"],
                },
            },
        },
    ],
    "messages": [
        {"role": "user", "content": "What's the weather like in San Francisco?"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "function": {"name": "get_unit", "arguments": '{"location": "San Francisco"}'},
                    "id": "tool_123",
                    "type": "function",
                }
            ],
            "prefix": False,
        },
        {"role": "tool", "name": "get_unit", "content": "celsius", "tool_call_id": "tool_123"},
    ],
}

DUMMY_CHAT_COMPLETION_WITH_TOOLS_RESPONSE = ChatCompletionResponse(
    id="test_id",
    object="chat.completion",
    model="test_model",
    usage=UsageInfo(prompt_tokens=11, completion_tokens=19, total_tokens=30),
    created=1736300000,
    choices=[
        ChatCompletionChoice(
            index=0,
            message=AssistantMessage(
role="assistant", 141 content="", 142 prefix=False, 143 tool_calls=[ 144 ToolCall( 145 function=FunctionCall( 146 name="get_weather", 147 arguments='{"location": "San Francisco", "unit": "celsius"}', 148 ), 149 id="tool_456", 150 type="function", 151 ), 152 ], 153 ), 154 finish_reason="tool_calls", 155 ) 156 ], 157 ) 158 159 160 def _make_httpx_response(response: BaseModel, status_code: int = 200) -> httpx.Response: 161 return httpx.Response( 162 status_code=status_code, 163 headers={"Content-Type": "application/json"}, 164 text=response.model_dump_json(), 165 ) 166 167 168 def test_chat_complete_autolog(mock_litellm_cost): 169 with patch( 170 CHAT_DO_REQUEST_PATH, 171 return_value=_make_httpx_response(DUMMY_CHAT_COMPLETION_RESPONSE), 172 ): 173 mlflow.mistral.autolog() 174 client = Mistral(api_key="test_key") 175 client.chat.complete(**DUMMY_CHAT_COMPLETION_REQUEST) 176 177 traces = get_traces() 178 assert len(traces) == 1 179 assert traces[0].info.status == "OK" 180 assert len(traces[0].data.spans) == 1 181 span = traces[0].data.spans[0] 182 assert span.name == "Chat.complete" 183 assert span.span_type == SpanType.CHAT_MODEL 184 assert span.inputs == DUMMY_CHAT_COMPLETION_REQUEST 185 # Only keep input_tokens / output_tokens fields in usage dict. 186 span.outputs["usage"] = { 187 key: span.outputs["usage"][key] 188 for key in ["prompt_tokens", "completion_tokens", "total_tokens"] 189 } 190 assert span.outputs == DUMMY_CHAT_COMPLETION_RESPONSE.model_dump() 191 assert span.model_name == "test_model" 192 assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "mistral" 193 assert traces[0].info.token_usage == { 194 TokenUsageKey.INPUT_TOKENS: 10, 195 TokenUsageKey.OUTPUT_TOKENS: 18, 196 TokenUsageKey.TOTAL_TOKENS: 28, 197 } 198 if not IS_TRACING_SDK_ONLY: 199 # Verify cost is calculated (10 input tokens * 1.0 + 18 output tokens * 2.0) 200 assert span.llm_cost == { 201 "input_cost": 10.0, 202 "output_cost": 36.0, 203 "total_cost": 46.0, 204 } 205 206 with patch( 207 CHAT_DO_REQUEST_PATH, 208 return_value=_make_httpx_response(DUMMY_CHAT_COMPLETION_RESPONSE), 209 ): 210 mlflow.mistral.autolog(disable=True) 211 client = Mistral(api_key="test_key") 212 client.chat.complete(**DUMMY_CHAT_COMPLETION_REQUEST) 213 214 # No new trace should be created 215 traces = get_traces() 216 assert len(traces) == 1 217 218 219 def test_chat_complete_autolog_tool_calling(): 220 with patch( 221 CHAT_DO_REQUEST_PATH, 222 return_value=_make_httpx_response(DUMMY_CHAT_COMPLETION_WITH_TOOLS_RESPONSE), 223 ): 224 mlflow.mistral.autolog() 225 client = Mistral(api_key="test_key") 226 client.chat.complete(**DUMMY_CHAT_COMPLETION_WITH_TOOLS_REQUEST) 227 228 traces = get_traces() 229 assert len(traces) == 1 230 assert traces[0].info.status == "OK" 231 assert len(traces[0].data.spans) == 1 232 span = traces[0].data.spans[0] 233 assert span.name == "Chat.complete" 234 assert span.span_type == SpanType.CHAT_MODEL 235 assert span.inputs == DUMMY_CHAT_COMPLETION_WITH_TOOLS_REQUEST 236 assert span.outputs == DUMMY_CHAT_COMPLETION_WITH_TOOLS_RESPONSE.model_dump() 237 assert span.model_name == "test_model" 238 239 assert span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == [ 240 { 241 "type": "function", 242 "function": { 243 "name": "get_unit", 244 "description": "Get the temperature unit commonly used in a given location", 245 "parameters": { 246 "properties": { 247 "location": { 248 "description": "The city and state, e.g., San Francisco, CA", 249 "type": "string", 250 }, 251 }, 252 "required": ["location"], 253 "type": 
"object", 254 }, 255 }, 256 }, 257 { 258 "type": "function", 259 "function": { 260 "name": "get_weather", 261 "description": "Get the current weather in a given location", 262 "parameters": { 263 "properties": { 264 "location": { 265 "description": "The city and state, e.g., San Francisco, CA", 266 "type": "string", 267 }, 268 "unit": { 269 "description": 'The unit of temperature, "celsius" or "fahrenheit"', 270 "enum": ["celsius", "fahrenheit"], 271 "type": "string", 272 }, 273 }, 274 "required": ["location", "unit"], 275 "type": "object", 276 }, 277 }, 278 }, 279 ] 280 assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "mistral" 281 assert traces[0].info.token_usage == { 282 TokenUsageKey.INPUT_TOKENS: 11, 283 TokenUsageKey.OUTPUT_TOKENS: 19, 284 TokenUsageKey.TOTAL_TOKENS: 30, 285 } 286 287 288 @pytest.mark.asyncio 289 async def test_chat_complete_async_autolog(): 290 with patch( 291 CHAT_DO_REQUEST_ASYNC_PATH, 292 return_value=_make_httpx_response(DUMMY_CHAT_COMPLETION_RESPONSE), 293 ): 294 mlflow.mistral.autolog() 295 client = Mistral(api_key="test_key") 296 await client.chat.complete_async(**DUMMY_CHAT_COMPLETION_REQUEST) 297 298 traces = get_traces() 299 assert len(traces) == 1 300 span = traces[0].data.spans[0] 301 assert span.name == "Chat.complete_async" 302 assert span.span_type == SpanType.CHAT_MODEL 303 assert span.inputs == DUMMY_CHAT_COMPLETION_REQUEST 304 span.outputs["usage"] = { 305 key: span.outputs["usage"][key] 306 for key in ["prompt_tokens", "completion_tokens", "total_tokens"] 307 } 308 assert span.outputs == DUMMY_CHAT_COMPLETION_RESPONSE.model_dump() 309 assert span.model_name == "test_model" 310 assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "mistral" 311 assert traces[0].info.token_usage == { 312 TokenUsageKey.INPUT_TOKENS: 10, 313 TokenUsageKey.OUTPUT_TOKENS: 18, 314 TokenUsageKey.TOTAL_TOKENS: 28, 315 }