# test_autogen_autolog.py
"""Tests for MLflow autologging (tracing) of autogen agents.

Each test drives an ``AssistantAgent`` backed by a ``ReplayChatCompletionClient``
(which replays canned responses instead of calling a real model) and then
inspects the MLflow traces produced by ``mlflow.autogen.autolog``.
"""

import pytest
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import MultiModalMessage
from autogen_core import FunctionCall, Image
from autogen_core.models import CreateResult
from autogen_ext.models.replay import ReplayChatCompletionClient

import mlflow
from mlflow.entities.span import SpanType
from mlflow.tracing.constant import SpanAttributeKey
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.tracing.helper import get_traces

_SYSTEM_MESSAGE = "You are a helpful assistant."
# Canned usage the replay client reports for a single turn.
_MODEL_USAGE = {"prompt_tokens": 6, "completion_tokens": 1}


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "disable",
    [True, False],
)
async def test_autolog_assistant_agent(disable, mock_litellm_cost):
    """A plain agent run produces agent + LLM spans (or no trace when disabled)."""
    model_client = ReplayChatCompletionClient(
        ["2"],
    )
    model_client.model = "gpt-4o-mini"
    agent = AssistantAgent("assistant", model_client=model_client, system_message=_SYSTEM_MESSAGE)

    mlflow.autogen.autolog(disable=disable)

    await agent.run(task="1+1")

    traces = get_traces()

    if disable:
        assert len(traces) == 0
    else:
        assert len(traces) == 1
        trace = traces[0]
        assert trace.info.status == "OK"
        assert len(trace.data.spans) == 3

        # Top-level agent span: the user task goes in, the message history comes out.
        span = trace.data.spans[0]
        assert span.name == "assistant.run"
        assert span.span_type == SpanType.AGENT
        assert span.inputs == {"task": "1+1"}
        messages = span.outputs["messages"]
        assert len(messages) == 2
        # `items() >= {...}.items()` is a superset check: extra keys (ids,
        # timestamps, ...) in the logged message are allowed.
        assert (
            messages[0].items()
            >= {
                "content": "1+1",
                "source": "user",
                "models_usage": None,
                "metadata": {},
                "type": "TextMessage",
            }.items()
        )
        assert (
            messages[1].items()
            >= {
                "content": "2",
                "source": "assistant",
                "models_usage": _MODEL_USAGE,
                "metadata": {},
                "type": "TextMessage",
            }.items()
        )

        span = trace.data.spans[1]
        assert span.name == "assistant.on_messages"
        assert span.span_type == SpanType.AGENT
        assert (
            span.outputs["chat_message"].items()
            >= {
                "source": "assistant",
                "models_usage": _MODEL_USAGE,
                "metadata": {},
                "content": "2",
                "type": "TextMessage",
            }.items()
        )

        # LLM span: request/response payloads plus usage and cost attributes.
        span = trace.data.spans[2]
        assert span.name == "ReplayChatCompletionClient.create"
        assert span.span_type == SpanType.LLM
        assert span.inputs["messages"] == [
            {"content": _SYSTEM_MESSAGE, "type": "SystemMessage"},
            {"content": "1+1", "source": "user", "type": "UserMessage"},
        ]
        assert span.outputs["content"] == "2"
        assert span.model_name == "gpt-4o-mini"

        assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
            "input_tokens": 6,
            "output_tokens": 1,
            "total_tokens": 7,
        }
        if not IS_TRACING_SDK_ONLY:
            # Verify cost is calculated (6 input tokens * 1.0 + 1 output tokens * 2.0)
            assert span.llm_cost == {
                "input_cost": 6.0,
                "output_cost": 2.0,
                "total_cost": 8.0,
            }

        assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "autogen"

        assert traces[0].info.token_usage == {
            "input_tokens": 6,
            "output_tokens": 1,
            "total_tokens": 7,
        }


@pytest.mark.asyncio
async def test_autolog_tool_agent(mock_litellm_cost):
    """A tool-calling run records tool request/execution events and tool attributes."""
    model_client = ReplayChatCompletionClient(
        [
            CreateResult(
                content=[FunctionCall(id="1", arguments='{"number": 1}', name="increment_number")],
                finish_reason="function_calls",
                usage=_MODEL_USAGE,
                cached=False,
            ),
        ],
    )
    model_client.model = "gpt-4o-mini"
    model_client.model_info["function_calling"] = True
    # Expected value of the `mlflow.chat.tools` span attribute for the tool below.
    tool_attributes = [
        {
            "function": {
                "name": "increment_number",
                "description": "Increment a number by 1.",
                "parameters": {
                    "type": "object",
                    "properties": {"number": {"description": "number", "type": "integer"}},
                    "required": ["number"],
                    "additionalProperties": False,
                },
                "strict": False,
            },
            "type": "function",
        }
    ]

    def increment_number(number: int) -> int:
        """Increment a number by 1."""
        return number + 1

    agent = AssistantAgent(
        "assistant",
        model_client=model_client,
        system_message=_SYSTEM_MESSAGE,
        tools=[increment_number],
    )
    mlflow.autogen.autolog()

    await agent.run(task="1+1")

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 3

    # Agent span: user task -> tool-call request -> tool execution -> summary.
    span = trace.data.spans[0]
    assert span.name == "assistant.run"
    assert span.span_type == SpanType.AGENT
    assert span.inputs == {"task": "1+1"}
    messages = span.outputs["messages"]
    assert len(messages) == 4
    # Superset checks: extra keys in the logged messages are tolerated.
    assert (
        messages[0].items()
        >= {
            "content": "1+1",
            "source": "user",
            "models_usage": None,
            "metadata": {},
            "type": "TextMessage",
        }.items()
    )

    assert (
        messages[1].items()
        >= {
            "content": [
                {
                    "id": "1",
                    "arguments": '{"number": 1}',
                    "name": "increment_number",
                }
            ],
            "source": "assistant",
            "models_usage": _MODEL_USAGE,
            "metadata": {},
            "type": "ToolCallRequestEvent",
        }.items()
    )
    assert (
        messages[2].items()
        >= {
            "content": [
                {
                    "call_id": "1",
                    "content": "2",
                    "is_error": False,
                    "name": "increment_number",
                }
            ],
            "source": "assistant",
            "models_usage": None,
            "metadata": {},
            "type": "ToolCallExecutionEvent",
        }.items()
    )
    assert (
        messages[3].items()
        >= {
            "content": "2",
            "source": "assistant",
            "models_usage": None,
            "metadata": {},
            "type": "ToolCallSummaryMessage",
        }.items()
    )

    span = trace.data.spans[1]
    assert span.name == "assistant.on_messages"
    assert span.span_type == SpanType.AGENT
    assert (
        span.outputs["chat_message"].items()
        >= {
            "source": "assistant",
            "models_usage": None,
            "metadata": {},
            "content": "2",
            "type": "ToolCallSummaryMessage",
        }.items()
    )
    assert span.get_attribute("mlflow.chat.tools") == tool_attributes

    # LLM span: tool schema and the function-call response are both captured.
    span = trace.data.spans[2]
    assert span.name == "ReplayChatCompletionClient.create"
    assert span.span_type == SpanType.LLM
    assert span.inputs["messages"] == [
        {"content": _SYSTEM_MESSAGE, "type": "SystemMessage"},
        {"content": "1+1", "source": "user", "type": "UserMessage"},
    ]
    assert span.get_attribute("mlflow.chat.tools") == tool_attributes
    assert span.outputs["content"] == [
        {"id": "1", "arguments": '{"number": 1}', "name": "increment_number"}
    ]
    assert span.model_name == "gpt-4o-mini"

    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        "input_tokens": 6,
        "output_tokens": 1,
        "total_tokens": 7,
    }
    if not IS_TRACING_SDK_ONLY:
        assert span.llm_cost == {
            "input_cost": 6.0,
            "output_cost": 2.0,
            "total_cost": 8.0,
        }

    assert traces[0].info.token_usage == {
        "input_tokens": 6,
        "output_tokens": 1,
        "total_tokens": 7,
    }


@pytest.mark.asyncio
async def test_autolog_multi_modal(mock_litellm_cost):
    """A multi-modal (text + image) task is traced with the image serialized to base64."""
    # Import the submodule explicitly: a bare `import PIL` does not guarantee
    # `PIL.Image` is available (it only worked before because another import
    # happened to load the submodule first).
    import PIL.Image

    pil_image = PIL.Image.new("RGB", (8, 8))
    img = Image(pil_image)
    user_message = "Can you describe the number in the image?"
    multi_modal_message = MultiModalMessage(content=[user_message, img], source="user")
    model_client = ReplayChatCompletionClient(
        ["2"],
    )
    model_client.model = "gpt-4o-mini"
    agent = AssistantAgent("assistant", model_client=model_client, system_message=_SYSTEM_MESSAGE)
    mlflow.autogen.autolog()

    await agent.run(task=multi_modal_message)

    traces = get_traces()

    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 3

    # Agent span: the image arrives serialized as a base64 `data` payload.
    span = trace.data.spans[0]
    assert span.name == "assistant.run"
    assert span.span_type == SpanType.AGENT
    assert span.inputs["task"]["content"][0] == "Can you describe the number in the image?"
    assert "data" in span.inputs["task"]["content"][1]
    messages = span.outputs["messages"]
    assert len(messages) == 2
    assert (
        messages[0].items()
        >= {
            "content": [
                "Can you describe the number in the image?",
                {
                    "data": "iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAADElEQVR4nGNgGB4AAADIAAGtQHYiAAAAAElFTkSuQmCC",  # noqa: E501
                },
            ],
            "source": "user",
            "models_usage": None,
            "metadata": {},
            "type": "MultiModalMessage",
        }.items()
    )
    assert (
        messages[1].items()
        >= {
            "content": "2",
            "source": "assistant",
            "models_usage": {"completion_tokens": 1, "prompt_tokens": 14},
            "metadata": {},
            "type": "TextMessage",
        }.items()
    )

    span = trace.data.spans[1]
    assert span.name == "assistant.on_messages"
    assert span.span_type == SpanType.AGENT
    assert (
        span.outputs["chat_message"].items()
        >= {
            "source": "assistant",
            "models_usage": {"completion_tokens": 1, "prompt_tokens": 14},
            "metadata": {},
            "content": "2",
            "type": "TextMessage",
        }.items()
    )

    # LLM span: the image is replaced by an `<image>` placeholder in the prompt.
    span = trace.data.spans[2]
    assert span.name == "ReplayChatCompletionClient.create"
    assert span.span_type == SpanType.LLM
    assert span.inputs["messages"] == [
        {"content": _SYSTEM_MESSAGE, "type": "SystemMessage"},
        {"content": f"{user_message}\n<image>", "source": "user", "type": "UserMessage"},
    ]
    assert span.outputs["content"] == "2"
    assert span.model_name == "gpt-4o-mini"

    assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        "input_tokens": 14,
        "output_tokens": 1,
        "total_tokens": 15,
    }
    if not IS_TRACING_SDK_ONLY:
        assert span.llm_cost == {
            "input_cost": 14.0,
            "output_cost": 2.0,
            "total_cost": 16.0,
        }

    assert traces[0].info.token_usage == {
        "input_tokens": 14,
        "output_tokens": 1,
        "total_tokens": 15,
    }