# tests/autogen/test_autogen_autolog.py
  1  import pytest
  2  from autogen_agentchat.agents import AssistantAgent
  3  from autogen_agentchat.messages import MultiModalMessage
  4  from autogen_core import FunctionCall, Image
  5  from autogen_core.models import CreateResult
  6  from autogen_ext.models.replay import ReplayChatCompletionClient
  7  
  8  import mlflow
  9  from mlflow.entities.span import SpanType
 10  from mlflow.tracing.constant import SpanAttributeKey
 11  from mlflow.version import IS_TRACING_SDK_ONLY
 12  
 13  from tests.tracing.helper import get_traces
 14  
# System prompt shared by every agent constructed in these tests.
_SYSTEM_MESSAGE = "You are a helpful assistant."
# Token usage attached to the replay client's canned completions and asserted
# against the recorded span/model usage below.
_MODEL_USAGE = {"prompt_tokens": 6, "completion_tokens": 1}
 17  
 18  
 19  @pytest.mark.asyncio
 20  @pytest.mark.parametrize(
 21      "disable",
 22      [True, False],
 23  )
 24  async def test_autolog_assistant_agent(disable, mock_litellm_cost):
 25      model_client = ReplayChatCompletionClient(
 26          ["2"],
 27      )
 28      model_client.model = "gpt-4o-mini"
 29      agent = AssistantAgent("assistant", model_client=model_client, system_message=_SYSTEM_MESSAGE)
 30  
 31      mlflow.autogen.autolog(disable=disable)
 32  
 33      await agent.run(task="1+1")
 34  
 35      traces = get_traces()
 36  
 37      if disable:
 38          assert len(traces) == 0
 39      else:
 40          assert len(traces) == 1
 41          trace = traces[0]
 42          assert trace.info.status == "OK"
 43          assert len(trace.data.spans) == 3
 44          span = trace.data.spans[0]
 45          assert span.name == "assistant.run"
 46          assert span.span_type == SpanType.AGENT
 47          assert span.inputs == {"task": "1+1"}
 48          messages = span.outputs["messages"]
 49          assert len(messages) == 2
 50          assert (
 51              messages[0].items()
 52              >= {
 53                  "content": "1+1",
 54                  "source": "user",
 55                  "models_usage": None,
 56                  "metadata": {},
 57                  "type": "TextMessage",
 58              }.items()
 59          )
 60          assert (
 61              messages[1].items()
 62              >= {
 63                  "content": "2",
 64                  "source": "assistant",
 65                  "models_usage": _MODEL_USAGE,
 66                  "metadata": {},
 67                  "type": "TextMessage",
 68              }.items()
 69          )
 70  
 71          span = trace.data.spans[1]
 72          assert span.name == "assistant.on_messages"
 73          assert span.span_type == SpanType.AGENT
 74          assert (
 75              span.outputs["chat_message"].items()
 76              >= {
 77                  "source": "assistant",
 78                  "models_usage": _MODEL_USAGE,
 79                  "metadata": {},
 80                  "content": "2",
 81                  "type": "TextMessage",
 82              }.items()
 83          )
 84  
 85          span = trace.data.spans[2]
 86          assert span.name == "ReplayChatCompletionClient.create"
 87          assert span.span_type == SpanType.LLM
 88          assert span.inputs["messages"] == [
 89              {"content": _SYSTEM_MESSAGE, "type": "SystemMessage"},
 90              {"content": "1+1", "source": "user", "type": "UserMessage"},
 91          ]
 92          assert span.outputs["content"] == "2"
 93          assert span.model_name == "gpt-4o-mini"
 94  
 95          assert span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
 96              "input_tokens": 6,
 97              "output_tokens": 1,
 98              "total_tokens": 7,
 99          }
100          if not IS_TRACING_SDK_ONLY:
101              # Verify cost is calculated (6 input tokens * 1.0 + 1 output tokens * 2.0)
102              assert span.llm_cost == {
103                  "input_cost": 6.0,
104                  "output_cost": 2.0,
105                  "total_cost": 8.0,
106              }
107  
108          assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "autogen"
109  
110          assert traces[0].info.token_usage == {
111              "input_tokens": 6,
112              "output_tokens": 1,
113              "total_tokens": 7,
114          }
115  
116  
@pytest.mark.asyncio
async def test_autolog_tool_agent(mock_litellm_cost):
    """A tool-calling run records the tool schema plus request/execution/summary messages."""

    def increment_number(number: int) -> int:
        """Increment a number by 1."""
        return number + 1

    client = ReplayChatCompletionClient(
        [
            CreateResult(
                content=[FunctionCall(id="1", arguments='{"number": 1}', name="increment_number")],
                finish_reason="function_calls",
                usage=_MODEL_USAGE,
                cached=False,
            ),
        ],
    )
    client.model = "gpt-4o-mini"
    client.model_info["function_calling"] = True

    # Tool schema that both the agent span and the LLM span must expose.
    expected_tool_attrs = [
        {
            "function": {
                "name": "increment_number",
                "description": "Increment a number by 1.",
                "parameters": {
                    "type": "object",
                    "properties": {"number": {"description": "number", "type": "integer"}},
                    "required": ["number"],
                    "additionalProperties": False,
                },
                "strict": False,
            },
            "type": "function",
        }
    ]

    assistant = AssistantAgent(
        "assistant",
        model_client=client,
        system_message=_SYSTEM_MESSAGE,
        tools=[increment_number],
    )
    mlflow.autogen.autolog()

    await assistant.run(task="1+1")

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 3

    # Root span: the agent run with the full message history.
    run_span = trace.data.spans[0]
    assert run_span.name == "assistant.run"
    assert run_span.span_type == SpanType.AGENT
    assert run_span.inputs == {"task": "1+1"}
    run_messages = run_span.outputs["messages"]
    assert len(run_messages) == 4

    expected_user_msg = {
        "content": "1+1",
        "source": "user",
        "models_usage": None,
        "metadata": {},
        "type": "TextMessage",
    }
    assert expected_user_msg.items() <= run_messages[0].items()

    expected_request_event = {
        "content": [
            {
                "id": "1",
                "arguments": '{"number": 1}',
                "name": "increment_number",
            }
        ],
        "source": "assistant",
        "models_usage": _MODEL_USAGE,
        "metadata": {},
        "type": "ToolCallRequestEvent",
    }
    assert expected_request_event.items() <= run_messages[1].items()

    expected_execution_event = {
        "content": [
            {
                "call_id": "1",
                "content": "2",
                "is_error": False,
                "name": "increment_number",
            }
        ],
        "source": "assistant",
        "models_usage": None,
        "metadata": {},
        "type": "ToolCallExecutionEvent",
    }
    assert expected_execution_event.items() <= run_messages[2].items()

    expected_summary_msg = {
        "content": "2",
        "source": "assistant",
        "models_usage": None,
        "metadata": {},
        "type": "ToolCallSummaryMessage",
    }
    assert expected_summary_msg.items() <= run_messages[3].items()

    # Middle span: on_messages, carrying the tool attributes.
    on_messages_span = trace.data.spans[1]
    assert on_messages_span.name == "assistant.on_messages"
    assert on_messages_span.span_type == SpanType.AGENT
    expected_chat_message = {
        "source": "assistant",
        "models_usage": None,
        "metadata": {},
        "content": "2",
        "type": "ToolCallSummaryMessage",
    }
    assert expected_chat_message.items() <= on_messages_span.outputs["chat_message"].items()
    assert on_messages_span.get_attribute("mlflow.chat.tools") == expected_tool_attrs

    # Leaf span: the model client invocation returning the function call.
    llm_span = trace.data.spans[2]
    assert llm_span.name == "ReplayChatCompletionClient.create"
    assert llm_span.span_type == SpanType.LLM
    assert llm_span.inputs["messages"] == [
        {"content": _SYSTEM_MESSAGE, "type": "SystemMessage"},
        {"content": "1+1", "source": "user", "type": "UserMessage"},
    ]
    assert llm_span.get_attribute("mlflow.chat.tools") == expected_tool_attrs
    assert llm_span.outputs["content"] == [
        {"id": "1", "arguments": '{"number": 1}', "name": "increment_number"}
    ]
    assert llm_span.model_name == "gpt-4o-mini"

    assert llm_span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        "input_tokens": 6,
        "output_tokens": 1,
        "total_tokens": 7,
    }
    if not IS_TRACING_SDK_ONLY:
        # Mocked litellm rates: 6 input tokens * 1.0 + 1 output token * 2.0.
        assert llm_span.llm_cost == {
            "input_cost": 6.0,
            "output_cost": 2.0,
            "total_cost": 8.0,
        }

    assert trace.info.token_usage == {
        "input_tokens": 6,
        "output_tokens": 1,
        "total_tokens": 7,
    }
273  
274  
@pytest.mark.asyncio
async def test_autolog_multi_modal(mock_litellm_cost):
    """A multi-modal (text + image) task is traced with the image serialized in the span data."""
    import PIL

    pil_image = PIL.Image.new("RGB", (8, 8))
    img = Image(pil_image)
    user_message = "Can you describe the number in the image?"
    task = MultiModalMessage(content=[user_message, img], source="user")

    client = ReplayChatCompletionClient(
        ["2"],
    )
    client.model = "gpt-4o-mini"
    assistant = AssistantAgent("assistant", model_client=client, system_message=_SYSTEM_MESSAGE)
    mlflow.autogen.autolog()

    await assistant.run(task=task)

    traces = get_traces()

    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "OK"
    assert len(trace.data.spans) == 3

    # Root span: the agent run with the multi-modal input.
    run_span = trace.data.spans[0]
    assert run_span.name == "assistant.run"
    assert run_span.span_type == SpanType.AGENT
    assert run_span.inputs["task"]["content"][0] == "Can you describe the number in the image?"
    # The image part is serialized with its raw data payload.
    assert "data" in run_span.inputs["task"]["content"][1]
    run_messages = run_span.outputs["messages"]
    assert len(run_messages) == 2

    expected_user_msg = {
        "content": [
            "Can you describe the number in the image?",
            {
                "data": "iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAADElEQVR4nGNgGB4AAADIAAGtQHYiAAAAAElFTkSuQmCC",  # noqa: E501
            },
        ],
        "source": "user",
        "models_usage": None,
        "metadata": {},
        "type": "MultiModalMessage",
    }
    assert expected_user_msg.items() <= run_messages[0].items()

    expected_reply_msg = {
        "content": "2",
        "source": "assistant",
        "models_usage": {"completion_tokens": 1, "prompt_tokens": 14},
        "metadata": {},
        "type": "TextMessage",
    }
    assert expected_reply_msg.items() <= run_messages[1].items()

    # Middle span: on_messages.
    on_messages_span = trace.data.spans[1]
    assert on_messages_span.name == "assistant.on_messages"
    assert on_messages_span.span_type == SpanType.AGENT
    expected_chat_message = {
        "source": "assistant",
        "models_usage": {"completion_tokens": 1, "prompt_tokens": 14},
        "metadata": {},
        "content": "2",
        "type": "TextMessage",
    }
    assert expected_chat_message.items() <= on_messages_span.outputs["chat_message"].items()

    # Leaf span: the model client invocation; the image collapses to an <image> placeholder.
    llm_span = trace.data.spans[2]
    assert llm_span.name == "ReplayChatCompletionClient.create"
    assert llm_span.span_type == SpanType.LLM
    assert llm_span.inputs["messages"] == [
        {"content": _SYSTEM_MESSAGE, "type": "SystemMessage"},
        {"content": f"{user_message}\n<image>", "source": "user", "type": "UserMessage"},
    ]
    assert llm_span.outputs["content"] == "2"
    assert llm_span.model_name == "gpt-4o-mini"

    assert llm_span.get_attribute(SpanAttributeKey.CHAT_USAGE) == {
        "input_tokens": 14,
        "output_tokens": 1,
        "total_tokens": 15,
    }
    if not IS_TRACING_SDK_ONLY:
        # Mocked litellm rates: 14 input tokens * 1.0 + 1 output token * 2.0.
        assert llm_span.llm_cost == {
            "input_cost": 14.0,
            "output_cost": 2.0,
            "total_cost": 16.0,
        }

    assert trace.info.token_usage == {
        "input_tokens": 14,
        "output_tokens": 1,
        "total_tokens": 15,
    }