# tests/dspy/test_dspy_autolog.py
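"""Tests for ``mlflow.dspy.autolog()``.

Covers tracing of DSPy LM, module, tool, and retriever calls, token usage
aggregation, error handling, compile/eval run logging, and linking traces to
logged or active models.
"""
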
import importlib.metadata
import json
import time
from unittest import mock

import dspy
import dspy.teleprompt
import pytest
from dspy.evaluate import Evaluate
from dspy.evaluate.metrics import answer_exact_match
from dspy.predict import Predict
from dspy.primitives.example import Example
from dspy.teleprompt import BootstrapFewShot
from dspy.utils.callback import BaseCallback, with_callbacks
from dspy.utils.dummies import DummyLM
from packaging.version import Version

import mlflow
from mlflow.entities import Feedback, LoggedModelOutput, SpanType, Trace
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.tracing.helper import get_traces, score_in_model_serving, skip_when_testing_trace_sdk

if not IS_TRACING_SDK_ONLY:
    from mlflow.tracking import MlflowClient


_DSPY_VERSION = Version(importlib.metadata.version("dspy"))

_DSPY_UNDER_2_6 = _DSPY_VERSION < Version("2.6.0rc1")

_DSPY_3_0_4_OR_NEWER = _DSPY_VERSION >= Version("3.0.4")


# Test module
class CoT(dspy.Module):
    def __init__(self):
        super().__init__()
        self.prog = dspy.ChainOfThought("question -> answer")
        mlflow.models.set_retriever_schema(
            primary_key="id",
            text_column="text",
            doc_uri="source",
        )

    def forward(self, question):
        return self.prog(question=question)


class DummyLMWithUsage(DummyLM):
    # Usage tracking had an issue before 3.0.4,
    # and DummyLM.__call__ cannot be overridden in 2.5.x.
    if _DSPY_3_0_4_OR_NEWER:

        def __call__(self, prompt=None, messages=None, **kwargs):
            if dspy.settings.usage_tracker:
                dspy.settings.usage_tracker.add_usage(
                    "openai/gpt-4.1",
                    {
                        "prompt_tokens": 5,
                        "completion_tokens": 7,
                        "total_tokens": 12,
                    },
                )

            return super().__call__(prompt, messages, **kwargs)


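# Note: DummyLM replays canned responses, and dspy's ChatAdapter wraps each
# output field in "[[ ## <field> ## ]]" markers, which is why those markers
# appear in the asserted outputs below.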
def test_autolog_lm():
    mlflow.dspy.autolog()

    lm = DummyLMWithUsage([{"output": "test output"}])
    result = lm("test input")
    assert result == ["[[ ## output ## ]]\ntest output"]

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    # The LM call may be too fast to register a > 0 ms difference, so only
    # check that the execution time is recorded.
    assert trace.info.execution_time_ms is not None

    spans = trace.data.spans
    assert len(spans) == 1
    assert spans[0].name == "DummyLMWithUsage.__call__"
    assert spans[0].span_type == SpanType.CHAT_MODEL
    assert spans[0].status.status_code == "OK"
    assert spans[0].inputs["prompt"] == "test input"
    assert spans[0].outputs == ["[[ ## output ## ]]\ntest output"]
    assert spans[0].attributes["model"] == "dummy"
    assert spans[0].attributes["model_type"] == "chat"
    assert spans[0].attributes["temperature"] == 0.0
    assert spans[0].attributes["max_tokens"] == 1000
    assert spans[0].model_name == "dummy"


def test_autolog_cot():
    mlflow.dspy.autolog()

    dspy.settings.configure(
        lm=DummyLMWithUsage({
            "How are you?": {"answer": "test output", "reasoning": "No more responses"}
        })
    )

    cot = dspy.ChainOfThought("question -> answer", n=3)

    result = cot(question="How are you?")
    assert result["answer"] == "test output"
    assert result["reasoning"] == "No more responses"

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0] is not None
    assert traces[0].info.status == "OK"
    assert traces[0].info.execution_time_ms > 0
    if _DSPY_3_0_4_OR_NEWER:
        assert traces[0].info.token_usage == {
            TokenUsageKey.INPUT_TOKENS: 5,
            TokenUsageKey.OUTPUT_TOKENS: 7,
            TokenUsageKey.TOTAL_TOKENS: 12,
        }

    spans = traces[0].data.spans
    assert len(spans) == 7
    assert spans[0].name == "ChainOfThought.forward"
    assert spans[0].span_type == SpanType.CHAIN
    assert spans[0].status.status_code == "OK"
    assert spans[0].inputs == {"question": "How are you?"}
    assert spans[0].outputs == {"answer": "test output", "reasoning": "No more responses"}
    # Put the conditional expression on the expected value; asserting
    # `x == a if cond else b` only checks the truthiness of `b` when the
    # condition is False.
    expected_signature = (
        "question -> answer" if _DSPY_UNDER_2_6 else "question -> reasoning, answer"
    )
    assert spans[0].attributes["signature"] == expected_signature
    if _DSPY_3_0_4_OR_NEWER:
        assert spans[0].attributes[SpanAttributeKey.CHAT_USAGE] == {
            TokenUsageKey.INPUT_TOKENS: 5,
            TokenUsageKey.OUTPUT_TOKENS: 7,
            TokenUsageKey.TOTAL_TOKENS: 12,
        }
    assert spans[1].name == "Predict.forward"
    assert spans[1].span_type == SpanType.LLM
    assert spans[1].inputs["question"] == "How are you?"
    assert spans[1].outputs == {"answer": "test output", "reasoning": "No more responses"}
    assert spans[2].name == "ChatAdapter.format"
    assert spans[2].span_type == SpanType.PARSER
    assert spans[2].inputs == {
        "inputs": {"question": "How are you?"},
        "demos": mock.ANY,
        "signature": mock.ANY,
    }
    assert spans[3].name == "DummyLMWithUsage.__call__"
    assert spans[3].span_type == SpanType.CHAT_MODEL
    assert spans[3].inputs == {
        "prompt": None,
        "messages": mock.ANY,
        "n": 3,
        "temperature": 0.7,
    }
    assert len(spans[3].outputs) == 3
    assert spans[3].model_name == "dummy"
    # Output parser will run per completion output (n=3)
    for i in range(3):
        assert spans[4 + i].name == "ChatAdapter.parse"
        assert spans[4 + i].span_type == SpanType.PARSER
        assert "question -> reasoning, answer" in spans[4 + i].inputs["signature"]


def test_mlflow_callback_exception():
    from litellm import ContextWindowExceededError

    mlflow.dspy.autolog()

    class ErrorLM(dspy.LM):
        @with_callbacks
        def __call__(self, prompt=None, messages=None, **kwargs):
            time.sleep(0.1)
            # dspy.ChatAdapter falls back to JSONAdapter on errors, except when
            # the error is a ContextWindowExceededError, which is re-raised
            raise ContextWindowExceededError("Error", "invalid model", "provider")

    cot = dspy.ChainOfThought("question -> answer", n=3)

    with dspy.context(
        lm=ErrorLM(
            model="invalid",
            prompt={"How are you?": {"answer": "test output", "reasoning": "No more responses"}},
        ),
    ):
        with pytest.raises(ContextWindowExceededError, match="Error"):
            cot(question="How are you?")

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "ERROR"
    assert trace.info.execution_time_ms > 0

    spans = trace.data.spans
    assert len(spans) == 4
    assert spans[0].name == "ChainOfThought.forward"
    assert spans[0].inputs == {"question": "How are you?"}
    assert spans[0].outputs is None
    assert spans[0].status.status_code == "ERROR"
    assert spans[1].name == "Predict.forward"
    assert spans[1].status.status_code == "ERROR"
    assert spans[2].name == "ChatAdapter.format"
    assert spans[2].status.status_code == "OK"
    assert spans[3].name == "ErrorLM.__call__"
    assert spans[3].status.status_code == "ERROR"


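# The three scripted responses below drive two ReAct iterations (a "search"
# tool call, then "finish") followed by the final answer-extraction
# ChainOfThought call, matching the span names asserted at the end.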
@pytest.mark.skipif(
    _DSPY_VERSION < Version("2.5.42"),
    reason="DSPy callback does not handle Tool in versions < 2.5.42",
)
def test_autolog_react():
    mlflow.dspy.autolog()

    dspy.settings.configure(
        lm=DummyLMWithUsage([
            {
                "next_thought": "I need to search for the highest mountain in the world",
                "next_tool_name": "search",
                "next_tool_args": {"query": "Highest mountain in the world"},
            },
            {
                "next_thought": "I found the highest mountain in the world",
                "next_tool_name": "finish",
                "next_tool_args": {"answer": "Mount Everest"},
            },
            {
                "answer": "Mount Everest",
                "reasoning": "No more responses",
            },
        ]),
        adapter=dspy.ChatAdapter(),
    )

    def search(query: str) -> str:
        return "Mount Everest"

    tools = [dspy.Tool(search)]
    react = dspy.ReAct("question -> answer", tools=tools)
    result = react(question="What is the highest mountain in the world?")
    assert result["answer"] == "Mount Everest"

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert trace.info.execution_time_ms > 0
    if _DSPY_3_0_4_OR_NEWER:
        assert trace.info.token_usage == {
            TokenUsageKey.INPUT_TOKENS: 15,
            TokenUsageKey.OUTPUT_TOKENS: 21,
            TokenUsageKey.TOTAL_TOKENS: 36,
        }

    spans = trace.data.spans
    assert len(spans) == 15
    assert [span.name for span in spans] == [
        "ReAct.forward",
        "Predict.forward",
        "ChatAdapter.format",
        "DummyLMWithUsage.__call__",
        "ChatAdapter.parse",
        "Tool.search",
        "Predict.forward",
        "ChatAdapter.format",
        "DummyLMWithUsage.__call__",
        "ChatAdapter.parse",
        "ChainOfThought.forward",
        "Predict.forward",
        "ChatAdapter.format",
        "DummyLMWithUsage.__call__",
        "ChatAdapter.parse",
    ]

    assert spans[3].span_type == SpanType.CHAT_MODEL
    assert spans[3].model_name == "dummy"
    assert spans[8].span_type == SpanType.CHAT_MODEL
    assert spans[8].model_name == "dummy"
    assert spans[13].span_type == SpanType.CHAT_MODEL
    assert spans[13].model_name == "dummy"


def test_autolog_retriever():
    mlflow.dspy.autolog()

    dspy.settings.configure(lm=DummyLM([{"output": "test output"}]))

    class DummyRetriever(dspy.Retrieve):
        def forward(self, query: str, n: int) -> list[str]:
            time.sleep(0.1)
            return ["test output"] * n

    retriever = DummyRetriever()
    result = retriever(query="test query", n=3)
    assert result == ["test output"] * 3

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert trace.info.execution_time_ms > 0

    spans = trace.data.spans
    assert len(spans) == 1
    assert spans[0].name == "DummyRetriever.forward"
    assert spans[0].span_type == SpanType.RETRIEVER
    assert spans[0].status.status_code == "OK"
    assert spans[0].inputs == {"query": "test query", "n": 3}
    assert spans[0].outputs == ["test output"] * 3


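# Module-level fixtures shared by the custom-module and Databricks model
# serving tests below: a minimal RAG program backed by a stub retriever.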
class DummyRetriever(dspy.Retrieve):
    def forward(self, query: str) -> list[str]:
        time.sleep(0.1)
        return ["test output"]


class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")


class RAG(dspy.Module):
    def __init__(self):
        super().__init__()

        self.retrieve = DummyRetriever()
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        # Create a custom span inside the module using the fluent API
        with mlflow.start_span(name="retrieve_context", span_type=SpanType.RETRIEVER) as span:
            span.set_inputs(question)
            docs = self.retrieve(question)
            context = "".join(docs)
            span.set_outputs(context)
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)


def test_autolog_custom_module():
    mlflow.dspy.autolog()

    dspy.settings.configure(
        lm=DummyLMWithUsage([
            {
                "answer": "test output",
                "reasoning": "No more responses",
            },
        ])
    )

    rag = RAG()
    result = rag("What castle did David Gregory inherit?")
    assert result.answer == "test output"

    traces = get_traces()
    assert len(traces) == 1, [trace.data.spans for trace in traces]
    assert traces[0] is not None
    assert traces[0].info.status == "OK"
    assert traces[0].info.execution_time_ms > 0
    if _DSPY_3_0_4_OR_NEWER:
        assert traces[0].info.token_usage == {
            TokenUsageKey.INPUT_TOKENS: 5,
            TokenUsageKey.OUTPUT_TOKENS: 7,
            TokenUsageKey.TOTAL_TOKENS: 12,
        }

    spans = traces[0].data.spans
    assert len(spans) == 8
    assert [span.name for span in spans] == [
        "RAG.forward",
        "retrieve_context",
        "DummyRetriever.forward",
        "ChainOfThought.forward",
        "Predict.forward",
        "ChatAdapter.format",
        "DummyLMWithUsage.__call__",
        "ChatAdapter.parse",
    ]


def test_autolog_tracing_during_compilation_disabled_by_default():
    mlflow.dspy.autolog()

    dspy.settings.configure(
        lm=DummyLM({
            "What is 1 + 1?": {"answer": "2"},
            "What is 2 + 2?": {"answer": "1000"},
        })
    )

    # Tiny arithmetic dataset; the stubbed LM answers the second question incorrectly
    trainset = [
        Example(question="What is 1 + 1?", answer="2").with_inputs("question"),
        Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
    ]

    program = Predict("question -> answer")

    # Compile should NOT generate traces by default
    teleprompter = BootstrapFewShot()
    teleprompter.compile(program, trainset=trainset)

    assert len(get_traces()) == 0

    # If opted in, traces should be generated during compilation
    mlflow.dspy.autolog(log_traces_from_compile=True)

    teleprompter.compile(program, trainset=trainset)

    traces = get_traces()
    assert len(traces) == 2
    assert all(trace.info.status == "OK" for trace in traces)

    # Opt out again
    mlflow.dspy.autolog(log_traces_from_compile=False)

    teleprompter.compile(program, trainset=trainset)
    assert len(get_traces()) == 2  # no new traces


def test_autolog_tracing_during_evaluation_enabled_by_default():
    mlflow.dspy.autolog()

    dspy.settings.configure(
        lm=DummyLM({
            "What is 1 + 1?": {"answer": "2"},
            "What is 2 + 2?": {"answer": "1000"},
        })
    )

    # Tiny arithmetic dataset; the stubbed LM answers the second question incorrectly
    trainset = [
        Example(question="What is 1 + 1?", answer="2").with_inputs("question"),
        Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
    ]

    program = Predict("question -> answer")

    # Evaluate should generate traces by default
    evaluator = Evaluate(devset=trainset)
    eval_res = evaluator(program, metric=answer_exact_match)

    # Evaluate returns a float in older DSPy versions and a result object since 2.7
    score = eval_res if isinstance(eval_res, float) else eval_res.score
    assert score == 50.0
    traces = get_traces()
    assert len(traces) == 2
    assert all(trace.info.status == "OK" for trace in traces)

    # If opted out, traces should NOT be generated during evaluation
    mlflow.dspy.autolog(log_traces_from_eval=False)

    evaluator(program, metric=answer_exact_match)
    assert len(get_traces()) == 2  # no new traces


def test_autolog_should_not_override_existing_callbacks():
    class CustomCallback(BaseCallback):
        pass

    callback = CustomCallback()

    dspy.settings.configure(callbacks=[callback])

    mlflow.dspy.autolog()
    assert callback in dspy.settings.callbacks

    mlflow.dspy.autolog(disable=True)
    assert callback in dspy.settings.callbacks


def test_disable_autolog():
    lm = DummyLM([{"output": "test output"}])
    mlflow.dspy.autolog()
    lm("test input")

    assert len(get_traces()) == 1

    mlflow.dspy.autolog(disable=True)

    lm("test input")

    # no additional trace should be created
    assert len(get_traces()) == 1

    mlflow.dspy.autolog(log_traces=False)

    lm("test input")

    # no additional trace should be created
    assert len(get_traces()) == 1


@skip_when_testing_trace_sdk
def test_autolog_set_retriever_schema():
    from mlflow.models.dependencies_schemas import DependenciesSchemasType, _clear_retriever_schema

    mlflow.dspy.autolog()
    dspy.settings.configure(
        lm=DummyLM([{"answer": answer, "reasoning": "reason"} for answer in ["4", "6", "8", "10"]])
    )

    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(CoT(), name="model")

    # Reset retriever schema
    _clear_retriever_schema()

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    loaded_model.predict({"question": "What is 2 + 2?"})

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    assert trace.info.status == "OK"
    assert json.loads(trace.info.tags[DependenciesSchemasType.RETRIEVERS.value]) == [
        {
            "name": "retriever",
            "primary_key": "id",
            "text_column": "text",
            "doc_uri": "source",
            "other_columns": [],
        }
    ]


@skip_when_testing_trace_sdk
@pytest.mark.parametrize("with_dependencies_schema", [True, False])
def test_dspy_auto_tracing_in_databricks_model_serving(with_dependencies_schema):
    from mlflow.models.dependencies_schemas import DependenciesSchemasType

    mlflow.dspy.autolog()

    dspy.settings.configure(
        lm=DummyLM(
            [
                {
                    "answer": "test output",
                    "reasoning": "No more responses",
                },
            ]
            * 2
        )
    )

    if with_dependencies_schema:
        mlflow.models.set_retriever_schema(
            primary_key="primary-key",
            text_column="text-column",
            doc_uri="doc-uri",
            other_columns=["column1", "column2"],
        )

    input_example = "What castle did David Gregory inherit?"

    with mlflow.start_run():
        model_info = mlflow.dspy.log_model(RAG(), name="model", input_example=input_example)

    databricks_request_id, response, trace_dict = score_in_model_serving(
        model_info.model_uri,
        input_example,
    )

    trace = Trace.from_dict(trace_dict)
    assert trace.info.trace_id.startswith("tr-")
    assert trace.info.client_request_id == databricks_request_id
    assert trace.info.status == "OK"

    spans = trace.data.spans
    assert len(spans) == 8
    assert [span.name for span in spans] == [
        "RAG.forward",
        "retrieve_context",
        "DummyRetriever.forward",
        "ChainOfThought.forward",
        "Predict.forward",
        "ChatAdapter.format",
        "DummyLM.__call__",
        "ChatAdapter.parse",
    ]

    if with_dependencies_schema:
        assert json.loads(trace.info.tags[DependenciesSchemasType.RETRIEVERS.value]) == [
            {
                "name": "retriever",
                "primary_key": "primary-key",
                "text_column": "text-column",
                "doc_uri": "doc-uri",
                "other_columns": ["column1", "column2"],
            }
        ]


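# The MLflow callback tracks optimizer nesting via optimizer_stack_level: 1
# inside a top-level compile, 2 inside a nested one, and back to 0 once the
# outermost compile returns, as the assertions in the next tests show.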
@skip_when_testing_trace_sdk
@pytest.mark.parametrize("log_compiles", [True, False])
def test_autolog_log_compile(log_compiles):
    class DummyOptimizer(dspy.teleprompt.Teleprompter):
        def compile(self, program, kwarg1=None, kwarg2=None):
            callback = dspy.settings.callbacks[0]
            assert callback.optimizer_stack_level == 1
            return program

    mlflow.dspy.autolog(log_compiles=log_compiles)
    dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}]))

    program = dspy.ChainOfThought("question -> answer")
    optimizer = DummyOptimizer()

    optimizer.compile(program, kwarg1=1, kwarg2="2")

    assert dspy.settings.callbacks[0].optimizer_stack_level == 0
    if log_compiles:
        run = mlflow.last_active_run()
        assert run is not None
        assert run.data.params == {
            "kwarg1": "1",
            "kwarg2": "2",
            "lm_params": json.dumps({
                "cache": True,
                "max_tokens": 1000,
                "model": "dummy",
                "model_type": "chat",
                "temperature": 0.0,
            }),
        }
        client = MlflowClient()
        # Materialize as a list so membership checks do not consume a generator
        artifacts = [x.path for x in client.list_artifacts(run.info.run_id)]
        assert "best_model.json" in artifacts

        # verify that a dummy model output is logged
        run = client.get_run(run.info.run_id)
        assert len(run.outputs.model_outputs) == 1
        assert isinstance(run.outputs.model_outputs[0], LoggedModelOutput)
    else:
        assert mlflow.last_active_run() is None


@skip_when_testing_trace_sdk
@pytest.mark.parametrize("log_compiles", [True, False])
def test_autolog_log_compile_log_model_output_when_failure(log_compiles):
    class DummyOptimizer(dspy.teleprompt.Teleprompter):
        def compile(self, program, kwarg1=None, kwarg2=None):
            raise Exception("test error")

    mlflow.dspy.autolog(log_compiles=log_compiles)
    dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}]))

    program = dspy.ChainOfThought("question -> answer")
    optimizer = DummyOptimizer()

    with pytest.raises(Exception, match="test error"):
        optimizer.compile(program, kwarg1=1, kwarg2="2")

    if log_compiles:
        run = mlflow.last_active_run()
        assert run is not None

        # verify that a dummy model output is logged even when compilation fails
        client = MlflowClient()
        run = client.get_run(run.info.run_id)
        assert len(run.outputs.model_outputs) == 1
        assert isinstance(run.outputs.model_outputs[0], LoggedModelOutput)
    else:
        assert mlflow.last_active_run() is None


@skip_when_testing_trace_sdk
def test_autolog_log_compile_disable():
    class DummyOptimizer(dspy.teleprompt.Teleprompter):
        def compile(self, program):
            return program

    mlflow.dspy.autolog(log_compiles=True)
    dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}]))

    program = dspy.ChainOfThought("question -> answer")
    optimizer = DummyOptimizer()

    optimizer.compile(program)

    run = mlflow.last_active_run()
    assert run is not None

    # verify that no run is created when autologging is disabled
    mlflow.dspy.autolog(disable=True)
    optimizer.compile(program)
    client = MlflowClient()
    runs = client.search_runs(run.info.experiment_id)
    assert len(runs) == 1


@skip_when_testing_trace_sdk
def test_autolog_log_nested_compile():
    class NestedOptimizer(dspy.teleprompt.Teleprompter):
        def compile(self, program):
            callback = dspy.settings.callbacks[0]
            assert callback.optimizer_stack_level == 2
            return program

    class DummyOptimizer(dspy.teleprompt.Teleprompter):
        def __init__(self):
            super().__init__()
            self.nested_optimizer = NestedOptimizer()

        def compile(self, program):
            self.nested_optimizer.compile(program)
            callback = dspy.settings.callbacks[0]
            assert callback.optimizer_stack_level == 1
            return program

    mlflow.dspy.autolog(log_compiles=True)
    dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}]))

    program = dspy.ChainOfThought("question -> answer")
    optimizer = DummyOptimizer()

    optimizer.compile(program)

    assert dspy.settings.callbacks[0].optimizer_stack_level == 0
    run = mlflow.last_active_run()
    assert run is not None
    client = MlflowClient()
    artifacts = [x.path for x in client.list_artifacts(run.info.run_id)]
    assert "best_model.json" in artifacts


skip_if_evaluate_callback_unavailable = pytest.mark.skipif(
    _DSPY_VERSION < Version("2.6.12"),
    reason="evaluate callback is available since 2.6.12",
)


# Evaluate.__call__ returns a dspy.Prediction since 2.7.0
is_2_7_or_newer = _DSPY_VERSION >= Version("2.7.0")


@skip_when_testing_trace_sdk
@skip_if_evaluate_callback_unavailable
@pytest.mark.parametrize("log_evals", [True, False])
@pytest.mark.parametrize("return_outputs", [True, False])
@pytest.mark.parametrize(
    ("lm", "examples", "expected_result_table"),
    [
        (
            DummyLM({
                "What is 1 + 1?": {"answer": "2"},
                "What is 2 + 2?": {"answer": "1000"},
            }),
            [
                Example(question="What is 1 + 1?", answer="2").with_inputs("question"),
                Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
            ],
            {
                "columns": ["score", "example_question", "example_answer", "pred_answer"],
                "data": [
                    [True, "What is 1 + 1?", "2", "2"],
                    [False, "What is 2 + 2?", "4", "1000"],
                ],
            },
        ),
        (
            DummyLM({
                "What is 1 + 1?": {"answer": "2"},
                "What is 2 + 2?": {"answer": "1000"},
            }),
            [
                Example(question="What is 1 + 1?", answer="2").with_inputs("question"),
                Example(question="What is 2 + 2?", answer="4", reason="should be 4").with_inputs(
                    "question"
                ),
            ],
            {
                "columns": [
                    "score",
                    "example_question",
                    "example_answer",
                    "pred_answer",
                    "example_reason",
                ],
                "data": [
                    [True, "What is 1 + 1?", "2", "2", None],
                    [False, "What is 2 + 2?", "4", "1000", "should be 4"],
                ],
            },
        ),
    ],
)
def test_autolog_log_evals(
    tmp_path, log_evals, return_outputs, lm, examples, expected_result_table
):
    mlflow.dspy.autolog(log_evals=log_evals)

    with dspy.context(lm=lm):
        program = Predict("question -> answer")
        if is_2_7_or_newer:
            evaluator = Evaluate(devset=examples, metric=answer_exact_match)
        else:
            # the return_outputs arg was removed in 2.7
            evaluator = Evaluate(
                devset=examples, metric=answer_exact_match, return_outputs=return_outputs
            )
        evaluator(program, devset=examples)

    run = mlflow.last_active_run()
    if log_evals:
        assert run is not None
        assert run.data.metrics == {"eval": 50.0}
        assert run.data.params == {
            "Predict.signature.fields.0.description": "${question}",
            "Predict.signature.fields.0.prefix": "Question:",
            "Predict.signature.fields.1.description": "${answer}",
            "Predict.signature.fields.1.prefix": "Answer:",
            "Predict.signature.instructions": "Given the fields `question`, produce the fields `answer`.",  # noqa: E501
            "lm_params": json.dumps({
                "cache": True,
                "max_tokens": 1000,
                "model": "dummy",
                "model_type": "chat",
                "temperature": 0.0,
            }),
        }
        client = MlflowClient()
        # Use a list: checking membership twice on a generator would skip items
        # consumed by the first check
        artifacts = [x.path for x in client.list_artifacts(run.info.run_id)]
        assert "model.json" in artifacts
        if is_2_7_or_newer:
            assert "result_table.json" in artifacts
            client.download_artifacts(
                run_id=run.info.run_id, path="result_table.json", dst_path=tmp_path
            )
            result_table = json.loads((tmp_path / "result_table.json").read_text())
            assert result_table == expected_result_table
    else:
        assert run is None


@skip_when_testing_trace_sdk
@skip_if_evaluate_callback_unavailable
def test_autolog_log_evals_disable_by_caller():
    mlflow.dspy.autolog(log_evals=True)
    examples = [
        Example(question="What is 1 + 1?", answer="2").with_inputs("question"),
    ]
    evaluator = Evaluate(devset=examples, metric=answer_exact_match)
    program = Predict("question -> answer")
    with dspy.context(lm=DummyLM([{"answer": "2"}])):
        evaluator(program, devset=examples, callback_metadata={"disable_logging": True})

    assert mlflow.last_active_run() is None


@skip_when_testing_trace_sdk
@skip_if_evaluate_callback_unavailable
def test_autolog_nested_evals():
    lm = DummyLM({
        "What is 1 + 1?": {"answer": "2"},
        "What is 2 + 2?": {"answer": "4"},
    })
    dspy.settings.configure(lm=lm)
    examples = [
        Example(question="What is 1 + 1?", answer="2").with_inputs("question"),
        # The expected answer deliberately mismatches the LM output ("4"),
        # so the second evaluation scores 0.0
        Example(question="What is 2 + 2?", answer="2").with_inputs("question"),
    ]
    program = Predict("question -> answer")
    evaluator = Evaluate(devset=examples, metric=answer_exact_match)

    mlflow.dspy.autolog(log_evals=True)
    with mlflow.start_run() as active_run:
        evaluator(program, devset=examples[:1])
        evaluator(program, devset=examples[1:])

    client = MlflowClient()
    run = client.get_run(active_run.info.run_id)
    assert run.data.metrics == {"eval": 0.0}

    artifacts = [x.path for x in client.list_artifacts(run.info.run_id)]
    assert "model.json" in artifacts

    metric_history = client.get_metric_history(run.info.run_id, "eval")
    assert [metric.value for metric in metric_history] == [100.0, 0.0]

    child_runs = client.search_runs(
        run.info.experiment_id,
        filter_string=f"tags.mlflow.parentRunId = '{run.info.run_id}'",
        order_by=["attributes.start_time ASC"],
    )

    assert len(child_runs) == 0


@skip_when_testing_trace_sdk
@skip_if_evaluate_callback_unavailable
@pytest.mark.parametrize("call_args", ["args", "kwargs", "mixed"])
def test_autolog_log_traces_from_evals(call_args):
    mlflow.dspy.autolog(log_evals=True, log_traces_from_eval=True)
    dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}]))

    class DummyProgram(dspy.Module):
        def forward(self, question):
            return dspy.Prediction(answer="2")

    examples = [
        Example(question="What is 1 + 1?", answer="2").with_inputs("question"),
        Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
    ]

    program = DummyProgram()
    evaluator = Evaluate(devset=examples, metric=answer_exact_match)

    if call_args == "args":
        result = evaluator(program, answer_exact_match, examples)
    elif call_args == "kwargs":
        result = evaluator(program=program, devset=examples, metric=answer_exact_match)
    else:
        result = evaluator(program, answer_exact_match, devset=examples)

    if _DSPY_VERSION >= Version("3.0.0"):
        from dspy.evaluate.evaluate import EvaluationResult

        assert isinstance(result, EvaluationResult)
    else:
        assert result is not None

    traces = get_traces()
    assert len(traces) == 2
    assert all(trace.info.status == "OK" for trace in traces)

    actual_values = []
    for trace in traces:
        assessments = trace.info.assessments
        assert len(assessments) == 1
        assert isinstance(assessments[0], Feedback)
        assert assessments[0].name == "answer_exact_match"
        actual_values.append(assessments[0].value)

    assert set(actual_values) == {True, False}


@skip_when_testing_trace_sdk
@skip_if_evaluate_callback_unavailable
def test_autolog_log_traces_from_evals_log_error_assessment():
    mlflow.dspy.autolog(log_evals=True, log_traces_from_eval=True)
    dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}]))

    class DummyProgram(dspy.Module):
        def forward(self, question):
            return dspy.Prediction(answer="2")

    def error_metric(example, pred):
        raise Exception("Error")

    examples = [Example(question="What is 1 + 1?", answer="2").with_inputs("question")]

    program = DummyProgram()
    evaluator = Evaluate(devset=examples, metric=error_metric)
    evaluator(program, error_metric, examples)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"

    # The metric failure is recorded as an error assessment on the trace
    assessments = traces[0].info.assessments
    assert len(assessments) == 1
    assert isinstance(assessments[0], Feedback)
    assert assessments[0].name == "error_metric"
    assert assessments[0].value is None
    assert assessments[0].error.error_code == "Exception"
    assert assessments[0].error.error_message == "Error"
    assert assessments[0].error.stack_trace is not None


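# Passing callback_metadata={"metric_key": ...} lets an optimizer route each
# evaluation's score to a differently named metric on the parent compile run,
# as the eval_full/eval_minibatch assertions below demonstrate.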
@skip_when_testing_trace_sdk
@skip_if_evaluate_callback_unavailable
def test_autolog_log_compile_with_evals():
    class EvalOptimizer(dspy.teleprompt.Teleprompter):
        def compile(self, program, eval, trainset, valset):
            eval(program, devset=valset, callback_metadata={"metric_key": "eval_full"})
            eval(program, devset=trainset[:1], callback_metadata={"metric_key": "eval_minibatch"})
            eval(program, devset=valset, callback_metadata={"metric_key": "eval_full"})
            eval(program, devset=trainset[:1], callback_metadata={"metric_key": "eval_minibatch"})
            return program

    dspy.settings.configure(
        lm=DummyLM({
            "What is 1 + 1?": {"answer": "2"},
            "What is 2 + 2?": {"answer": "1000"},
        })
    )
    dataset = [
        Example(question="What is 1 + 1?", answer="2").with_inputs("question"),
        Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
    ]
    program = Predict("question -> answer")
    evaluator = Evaluate(devset=dataset, metric=answer_exact_match)
    optimizer = EvalOptimizer()

    mlflow.dspy.autolog(log_compiles=True, log_evals=True)
    optimizer.compile(program, evaluator, trainset=dataset, valset=dataset)

    # callback state is reset after compilation completes
    callback = dspy.settings.callbacks[0]
    assert callback.optimizer_stack_level == 0
    assert callback._call_id_to_metric_key == {}
    assert callback._evaluation_counter == {}

    # root run
    root_run = mlflow.last_active_run()
    assert root_run is not None
    client = MlflowClient()
    artifacts = [x.path for x in client.list_artifacts(root_run.info.run_id)]
    assert "best_model.json" in artifacts
    assert "trainset.json" in artifacts
    assert "valset.json" in artifacts
    assert root_run.data.metrics == {
        "eval_full": 50.0,
        "eval_minibatch": 100.0,
    }

    # child runs
    child_runs = client.search_runs(
        root_run.info.experiment_id,
        filter_string=f"tags.mlflow.parentRunId = '{root_run.info.run_id}'",
        order_by=["attributes.start_time ASC"],
    )
    assert len(child_runs) == 4

    for i, run in enumerate(child_runs):
        if i % 2 == 0:
            assert run.data.metrics == {"eval": 50.0}
        else:
            assert run.data.metrics == {"eval": 100.0}
        artifacts = [x.path for x in client.list_artifacts(run.info.run_id)]
        assert "model.json" in artifacts
        assert run.data.params == {
            "Predict.signature.fields.0.description": "${question}",
            "Predict.signature.fields.0.prefix": "Question:",
            "Predict.signature.fields.1.description": "${answer}",
            "Predict.signature.fields.1.prefix": "Answer:",
            "Predict.signature.instructions": "Given the fields `question`, produce the fields `answer`.",  # noqa: E501
            "lm_params": json.dumps({
                "cache": True,
                "max_tokens": 1000,
                "model": "dummy",
                "model_type": "chat",
                "temperature": 0.0,
            }),
        }


@skip_when_testing_trace_sdk
def test_autolog_link_traces_loaded_model_custom_module():
    mlflow.dspy.autolog()
    dspy.settings.configure(
        lm=DummyLM([{"answer": "test output", "reasoning": "No more responses"}] * 5)
    )
    dspy_model = CoT()

    model_infos = []
    for _ in range(5):
        with mlflow.start_run():
            model_infos.append(mlflow.dspy.log_model(dspy_model, name="model", pip_requirements=[]))

    for model_info in model_infos:
        loaded_model = mlflow.dspy.load_model(model_info.model_uri)
        # Pass the model_id as the program input so it can be read back from
        # the trace request and compared against the linked model ID
        loaded_model(model_info.model_id)

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        model_id = json.loads(trace.data.request)["args"][0]
        assert model_id == trace.info.request_metadata[TraceMetadataKey.MODEL_ID]


@skip_when_testing_trace_sdk
def test_autolog_link_traces_loaded_model_custom_module_pyfunc():
    mlflow.dspy.autolog()
    dspy.settings.configure(
        lm=DummyLM([{"answer": "test output", "reasoning": "No more responses"}] * 5)
    )
    dspy_model = CoT()

    model_infos = []
    for _ in range(5):
        with mlflow.start_run():
            model_infos.append(mlflow.dspy.log_model(dspy_model, name="model", pip_requirements=[]))

    for model_info in model_infos:
        pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
        pyfunc_model.predict(model_info.model_id)

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        model_id = json.loads(trace.data.request)["args"][0]
        assert model_id == trace.info.request_metadata[TraceMetadataKey.MODEL_ID]


@skip_when_testing_trace_sdk
def test_autolog_link_traces_active_model():
    model = mlflow.create_external_model(name="test_model")
    mlflow.set_active_model(model_id=model.model_id)
    mlflow.dspy.autolog()
    dspy.settings.configure(
        lm=DummyLM([{"answer": "test output", "reasoning": "No more responses"}] * 5)
    )
    dspy_model = CoT()

    model_infos = []
    for _ in range(5):
        with mlflow.start_run():
            model_infos.append(mlflow.dspy.log_model(dspy_model, name="model", pip_requirements=[]))

    for model_info in model_infos:
        pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
        pyfunc_model.predict(model_info.model_id)

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        model_id = json.loads(trace.data.request)["args"][0]
        # The active model takes precedence over the logged model for linking
        assert model_id != model.model_id
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == model.model_id


@skip_when_testing_trace_sdk
def test_model_loading_set_active_model_id_without_fetching_logged_model():
    mlflow.dspy.autolog()
    dspy.settings.configure(
        lm=DummyLM([{"answer": "test output", "reasoning": "No more responses"}])
    )
    dspy_model = CoT()

    model_info = mlflow.dspy.log_model(dspy_model, name="model", pip_requirements=[])

    with mock.patch("mlflow.get_logged_model", side_effect=Exception("get_logged_model failed")):
        loaded_model = mlflow.dspy.load_model(model_info.model_uri)
    loaded_model(model_info.model_id)

    traces = get_traces()
    assert len(traces) == 1
    model_id = json.loads(traces[0].data.request)["args"][0]
    assert model_id == traces[0].info.request_metadata[TraceMetadataKey.MODEL_ID]


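# Autologging converts DatabricksRM retriever predictions into MLflow's
# retriever document format (page_content plus doc_id/doc_uri metadata), as
# asserted on the span outputs below.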
def test_autolog_databricks_rm_retriever():
    mlflow.dspy.autolog()

    dspy.settings.configure(lm=DummyLM([{"output": "test output"}]))

    class DatabricksRM(dspy.Retrieve):
        def __init__(self, retrieve_uri):
            self.retrieve_uri = retrieve_uri

        def forward(self, query) -> dspy.Prediction:
            time.sleep(0.1)
            return dspy.Prediction(
                docs=["doc1", "doc2"],
                doc_ids=["id1", "id2"],
                doc_uris=["uri1", "uri2"] if self.retrieve_uri else None,
                extra_columns=[{"author": "Jim"}, {"author": "tom"}],
            )

    # Pretend this stub comes from the real module so autologging treats it
    # as a DatabricksRM retriever
    DatabricksRM.__module__ = "dspy.retrieve.databricks_rm"

    for retrieve_uri in [False, True]:
        retriever = DatabricksRM(retrieve_uri)
        result = retriever(query="test query")
        assert isinstance(result, dspy.Prediction)

        trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
        assert trace is not None
        assert trace.info.status == "OK"
        assert trace.info.execution_time_ms > 0

        spans = trace.data.spans
        assert len(spans) == 1
        assert spans[0].name == "DatabricksRM.forward"
        assert spans[0].span_type == SpanType.RETRIEVER
        assert spans[0].status.status_code == "OK"
        assert spans[0].inputs == {"query": "test query"}

        if retrieve_uri:
            uri1 = "uri1"
            uri2 = "uri2"
        else:
            uri1 = None
            uri2 = None

        assert spans[0].outputs == [
            {
                "page_content": "doc1",
                "metadata": {"doc_id": "id1", "doc_uri": uri1, "author": "Jim"},
                "id": "id1",
            },
            {
                "page_content": "doc2",
                "metadata": {"doc_id": "id2", "doc_uri": uri2, "author": "tom"},
                "id": "id2",
            },
        ]