test_dspy_autolog.py
1 import importlib 2 import json 3 import time 4 from unittest import mock 5 6 import dspy 7 import dspy.teleprompt 8 import pytest 9 from dspy.evaluate import Evaluate 10 from dspy.evaluate.metrics import answer_exact_match 11 from dspy.predict import Predict 12 from dspy.primitives.example import Example 13 from dspy.teleprompt import BootstrapFewShot 14 from dspy.utils.callback import BaseCallback, with_callbacks 15 from dspy.utils.dummies import DummyLM 16 from packaging.version import Version 17 18 import mlflow 19 from mlflow.entities import Feedback, LoggedModelOutput, SpanType, Trace 20 from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey 21 from mlflow.version import IS_TRACING_SDK_ONLY 22 23 from tests.tracing.helper import get_traces, score_in_model_serving, skip_when_testing_trace_sdk 24 25 if not IS_TRACING_SDK_ONLY: 26 from mlflow.tracking import MlflowClient 27 28 29 _DSPY_VERSION = Version(importlib.metadata.version("dspy")) 30 31 _DSPY_UNDER_2_6 = _DSPY_VERSION < Version("2.6.0rc1") 32 33 _DSPY_3_0_4_OR_NEWER = _DSPY_VERSION >= Version("3.0.4") 34 35 36 # Test module 37 class CoT(dspy.Module): 38 def __init__(self): 39 super().__init__() 40 self.prog = dspy.ChainOfThought("question -> answer") 41 mlflow.models.set_retriever_schema( 42 primary_key="id", 43 text_column="text", 44 doc_uri="source", 45 ) 46 47 def forward(self, question): 48 return self.prog(question=question) 49 50 51 class DummyLMWithUsage(DummyLM): 52 # Usage tracking had an issue before3.0.4 53 # and DummyLM.__call__ cannot be overridden in 2.5.x 54 if _DSPY_3_0_4_OR_NEWER: 55 56 def __call__(self, prompt=None, messages=None, **kwargs): 57 if dspy.settings.usage_tracker: 58 dspy.settings.usage_tracker.add_usage( 59 "openai/gpt-4.1", 60 { 61 "prompt_tokens": 5, 62 "completion_tokens": 7, 63 "total_tokens": 12, 64 }, 65 ) 66 67 return super().__call__(prompt, messages, **kwargs) 68 69 70 def test_autolog_lm(): 71 mlflow.dspy.autolog() 72 73 lm = DummyLMWithUsage([{"output": "test output"}]) 74 result = lm("test input") 75 assert result == ["[[ ## output ## ]]\ntest output"] 76 77 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 78 assert trace is not None 79 assert trace.info.status == "OK" 80 # Latency of LM is too small to get > 0 milliseconds difference 81 assert trace.info.execution_time_ms is not None 82 83 spans = trace.data.spans 84 assert len(spans) == 1 85 assert spans[0].name == "DummyLMWithUsage.__call__" 86 assert spans[0].span_type == SpanType.CHAT_MODEL 87 assert spans[0].status.status_code == "OK" 88 assert spans[0].inputs["prompt"] == "test input" 89 assert spans[0].outputs == ["[[ ## output ## ]]\ntest output"] 90 assert spans[0].attributes["model"] == "dummy" 91 assert spans[0].attributes["model_type"] == "chat" 92 assert spans[0].attributes["temperature"] == 0.0 93 assert spans[0].attributes["max_tokens"] == 1000 94 assert spans[0].model_name == "dummy" 95 96 97 def test_autolog_cot(): 98 mlflow.dspy.autolog() 99 100 dspy.settings.configure( 101 lm=DummyLMWithUsage({ 102 "How are you?": {"answer": "test output", "reasoning": "No more responses"} 103 }) 104 ) 105 106 cot = dspy.ChainOfThought("question -> answer", n=3) 107 108 result = cot(question="How are you?") 109 assert result["answer"] == "test output" 110 assert result["reasoning"] == "No more responses" 111 112 traces = get_traces() 113 assert len(traces) == 1 114 assert traces[0] is not None 115 assert traces[0].info.status == "OK" 116 assert traces[0].info.execution_time_ms > 0 117 if _DSPY_3_0_4_OR_NEWER: 118 assert traces[0].info.token_usage == { 119 TokenUsageKey.INPUT_TOKENS: 5, 120 TokenUsageKey.OUTPUT_TOKENS: 7, 121 TokenUsageKey.TOTAL_TOKENS: 12, 122 } 123 124 spans = traces[0].data.spans 125 assert len(spans) == 7 126 assert spans[0].name == "ChainOfThought.forward" 127 assert spans[0].span_type == SpanType.CHAIN 128 assert spans[0].status.status_code == "OK" 129 assert spans[0].inputs == {"question": "How are you?"} 130 assert spans[0].outputs == {"answer": "test output", "reasoning": "No more responses"} 131 assert ( 132 spans[0].attributes["signature"] == "question -> answer" 133 if _DSPY_UNDER_2_6 134 else "question -> reasoning, answer" 135 ) 136 if _DSPY_3_0_4_OR_NEWER: 137 assert spans[0].attributes[SpanAttributeKey.CHAT_USAGE] == { 138 TokenUsageKey.INPUT_TOKENS: 5, 139 TokenUsageKey.OUTPUT_TOKENS: 7, 140 TokenUsageKey.TOTAL_TOKENS: 12, 141 } 142 assert spans[1].name == "Predict.forward" 143 assert spans[1].span_type == SpanType.LLM 144 assert spans[1].inputs["question"] == "How are you?" 145 assert spans[1].outputs == {"answer": "test output", "reasoning": "No more responses"} 146 assert spans[2].name == "ChatAdapter.format" 147 assert spans[2].span_type == SpanType.PARSER 148 assert spans[2].inputs == { 149 "inputs": {"question": "How are you?"}, 150 "demos": mock.ANY, 151 "signature": mock.ANY, 152 } 153 assert spans[3].name == "DummyLMWithUsage.__call__" 154 assert spans[3].span_type == SpanType.CHAT_MODEL 155 assert spans[3].inputs == { 156 "prompt": None, 157 "messages": mock.ANY, 158 "n": 3, 159 "temperature": 0.7, 160 } 161 assert len(spans[3].outputs) == 3 162 assert spans[3].model_name == "dummy" 163 # Output parser will run per completion output (n=3) 164 for i in range(3): 165 assert spans[4 + i].name == "ChatAdapter.parse" 166 assert spans[4 + i].span_type == SpanType.PARSER 167 assert "question -> reasoning, answer" in spans[4 + i].inputs["signature"] 168 169 170 def test_mlflow_callback_exception(): 171 from litellm import ContextWindowExceededError 172 173 mlflow.dspy.autolog() 174 175 class ErrorLM(dspy.LM): 176 @with_callbacks 177 def __call__(self, prompt=None, messages=None, **kwargs): 178 time.sleep(0.1) 179 # pdpy.ChatAdapter falls back to JSONAdapter unless it's not ContextWindowExceededError 180 raise ContextWindowExceededError("Error", "invalid model", "provider") 181 182 cot = dspy.ChainOfThought("question -> answer", n=3) 183 184 with dspy.context( 185 lm=ErrorLM( 186 model="invalid", 187 prompt={"How are you?": {"answer": "test output", "reasoning": "No more responses"}}, 188 ), 189 ): 190 with pytest.raises(ContextWindowExceededError, match="Error"): 191 cot(question="How are you?") 192 193 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 194 assert trace is not None 195 assert trace.info.status == "ERROR" 196 assert trace.info.execution_time_ms > 0 197 198 spans = trace.data.spans 199 assert len(spans) == 4 200 assert spans[0].name == "ChainOfThought.forward" 201 assert spans[0].inputs == {"question": "How are you?"} 202 assert spans[0].outputs is None 203 assert spans[0].status.status_code == "ERROR" 204 assert spans[1].name == "Predict.forward" 205 assert spans[1].status.status_code == "ERROR" 206 assert spans[2].name == "ChatAdapter.format" 207 assert spans[2].status.status_code == "OK" 208 assert spans[3].name == "ErrorLM.__call__" 209 assert spans[3].status.status_code == "ERROR" 210 211 212 @pytest.mark.skipif( 213 _DSPY_VERSION < Version("2.5.42"), 214 reason="DSPy callback does not handle Tool in versions < 2.5.42", 215 ) 216 def test_autolog_react(): 217 mlflow.dspy.autolog() 218 219 dspy.settings.configure( 220 lm=DummyLMWithUsage([ 221 { 222 "next_thought": "I need to search for the highest mountain in the world", 223 "next_tool_name": "search", 224 "next_tool_args": {"query": "Highest mountain in the world"}, 225 }, 226 { 227 "next_thought": "I found the highest mountain in the world", 228 "next_tool_name": "finish", 229 "next_tool_args": {"answer": "Mount Everest"}, 230 }, 231 { 232 "answer": "Mount Everest", 233 "reasoning": "No more responses", 234 }, 235 ]), 236 adapter=dspy.ChatAdapter(), 237 ) 238 239 def search(query: str) -> list[str]: 240 return "Mount Everest" 241 242 tools = [dspy.Tool(search)] 243 react = dspy.ReAct("question -> answer", tools=tools) 244 result = react(question="What is the highest mountain in the world?") 245 assert result["answer"] == "Mount Everest" 246 247 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 248 assert trace is not None 249 assert trace.info.status == "OK" 250 assert trace.info.execution_time_ms > 0 251 if _DSPY_3_0_4_OR_NEWER: 252 assert trace.info.token_usage == { 253 TokenUsageKey.INPUT_TOKENS: 15, 254 TokenUsageKey.OUTPUT_TOKENS: 21, 255 TokenUsageKey.TOTAL_TOKENS: 36, 256 } 257 258 spans = trace.data.spans 259 assert len(spans) == 15 260 assert [span.name for span in spans] == [ 261 "ReAct.forward", 262 "Predict.forward", 263 "ChatAdapter.format", 264 "DummyLMWithUsage.__call__", 265 "ChatAdapter.parse", 266 "Tool.search", 267 "Predict.forward", 268 "ChatAdapter.format", 269 "DummyLMWithUsage.__call__", 270 "ChatAdapter.parse", 271 "ChainOfThought.forward", 272 "Predict.forward", 273 "ChatAdapter.format", 274 "DummyLMWithUsage.__call__", 275 "ChatAdapter.parse", 276 ] 277 278 assert spans[3].span_type == SpanType.CHAT_MODEL 279 assert spans[3].model_name == "dummy" 280 assert spans[8].span_type == SpanType.CHAT_MODEL 281 assert spans[8].model_name == "dummy" 282 assert spans[13].span_type == SpanType.CHAT_MODEL 283 assert spans[13].model_name == "dummy" 284 285 286 def test_autolog_retriever(): 287 mlflow.dspy.autolog() 288 289 dspy.settings.configure(lm=DummyLM([{"output": "test output"}])) 290 291 class DummyRetriever(dspy.Retrieve): 292 def forward(self, query: str, n: int) -> list[str]: 293 time.sleep(0.1) 294 return ["test output"] * n 295 296 retriever = DummyRetriever() 297 result = retriever(query="test query", n=3) 298 assert result == ["test output"] * 3 299 300 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 301 assert trace is not None 302 assert trace.info.status == "OK" 303 assert trace.info.execution_time_ms > 0 304 305 spans = trace.data.spans 306 assert len(spans) == 1 307 assert spans[0].name == "DummyRetriever.forward" 308 assert spans[0].span_type == SpanType.RETRIEVER 309 assert spans[0].status.status_code == "OK" 310 assert spans[0].inputs == {"query": "test query", "n": 3} 311 assert spans[0].outputs == ["test output"] * 3 312 313 314 class DummyRetriever(dspy.Retrieve): 315 def forward(self, query: str) -> list[str]: 316 time.sleep(0.1) 317 return ["test output"] 318 319 320 class GenerateAnswer(dspy.Signature): 321 """Answer questions with short factoid answers.""" 322 323 context = dspy.InputField(desc="may contain relevant facts") 324 question = dspy.InputField() 325 answer = dspy.OutputField(desc="often between 1 and 5 words") 326 327 328 class RAG(dspy.Module): 329 def __init__(self): 330 super().__init__() 331 332 self.retrieve = DummyRetriever() 333 self.generate_answer = dspy.ChainOfThought(GenerateAnswer) 334 335 def forward(self, question): 336 # Create a custom span inside the module using fluent API 337 with mlflow.start_span(name="retrieve_context", span_type=SpanType.RETRIEVER) as span: 338 span.set_inputs(question) 339 docs = self.retrieve(question) 340 context = "".join(docs) 341 span.set_outputs(context) 342 prediction = self.generate_answer(context=context, question=question) 343 return dspy.Prediction(context=context, answer=prediction.answer) 344 345 346 def test_autolog_custom_module(): 347 mlflow.dspy.autolog() 348 349 dspy.settings.configure( 350 lm=DummyLMWithUsage([ 351 { 352 "answer": "test output", 353 "reasoning": "No more responses", 354 }, 355 ]) 356 ) 357 358 rag = RAG() 359 result = rag("What castle did David Gregory inherit?") 360 assert result.answer == "test output" 361 362 traces = get_traces() 363 assert len(traces) == 1, [trace.data.spans for trace in traces] 364 assert traces[0] is not None 365 assert traces[0].info.status == "OK" 366 assert traces[0].info.execution_time_ms > 0 367 if _DSPY_3_0_4_OR_NEWER: 368 assert traces[0].info.token_usage == { 369 TokenUsageKey.INPUT_TOKENS: 5, 370 TokenUsageKey.OUTPUT_TOKENS: 7, 371 TokenUsageKey.TOTAL_TOKENS: 12, 372 } 373 374 spans = traces[0].data.spans 375 assert len(spans) == 8 376 assert [span.name for span in spans] == [ 377 "RAG.forward", 378 "retrieve_context", 379 "DummyRetriever.forward", 380 "ChainOfThought.forward", 381 "Predict.forward", 382 "ChatAdapter.format", 383 "DummyLMWithUsage.__call__", 384 "ChatAdapter.parse", 385 ] 386 387 388 def test_autolog_tracing_during_compilation_disabled_by_default(): 389 mlflow.dspy.autolog() 390 391 dspy.settings.configure( 392 lm=DummyLM({ 393 "What is 1 + 1?": {"answer": "2"}, 394 "What is 2 + 2?": {"answer": "1000"}, 395 }) 396 ) 397 398 # Samples from HotpotQA dataset 399 trainset = [ 400 Example(question="What is 1 + 1?", answer="2").with_inputs("question"), 401 Example(question="What is 2 + 2?", answer="4").with_inputs("question"), 402 ] 403 404 program = Predict("question -> answer") 405 406 # Compile should NOT generate traces by default 407 teleprompter = BootstrapFewShot() 408 teleprompter.compile(program, trainset=trainset) 409 410 assert len(get_traces()) == 0 411 412 # If opted in, traces should be generated during compilation 413 mlflow.dspy.autolog(log_traces_from_compile=True) 414 415 teleprompter.compile(program, trainset=trainset) 416 417 traces = get_traces() 418 assert len(traces) == 2 419 assert all(trace.info.status == "OK" for trace in traces) 420 421 # Opt-out again 422 mlflow.dspy.autolog(log_traces_from_compile=False) 423 424 teleprompter.compile(program, trainset=trainset) 425 assert len(get_traces()) == 2 # no new traces 426 427 428 def test_autolog_tracing_during_evaluation_enabled_by_default(): 429 mlflow.dspy.autolog() 430 431 dspy.settings.configure( 432 lm=DummyLM({ 433 "What is 1 + 1?": {"answer": "2"}, 434 "What is 2 + 2?": {"answer": "1000"}, 435 }) 436 ) 437 438 # Samples from HotpotQA dataset 439 trainset = [ 440 Example(question="What is 1 + 1?", answer="2").with_inputs("question"), 441 Example(question="What is 2 + 2?", answer="4").with_inputs("question"), 442 ] 443 444 program = Predict("question -> answer") 445 446 # Evaluate should generate traces by default 447 evaluator = Evaluate(devset=trainset) 448 eval_res = evaluator(program, metric=answer_exact_match) 449 450 score = eval_res if isinstance(eval_res, float) else eval_res.score 451 assert score == 50.0 452 traces = get_traces() 453 assert len(traces) == 2 454 assert all(trace.info.status == "OK" for trace in traces) 455 456 # If opted out, traces should NOT be generated during evaluation 457 mlflow.dspy.autolog(log_traces_from_eval=False) 458 459 score = evaluator(program, metric=answer_exact_match) 460 assert len(get_traces()) == 2 # no new traces 461 462 463 def test_autolog_should_not_override_existing_callbacks(): 464 class CustomCallback(BaseCallback): 465 pass 466 467 callback = CustomCallback() 468 469 dspy.settings.configure(callbacks=[callback]) 470 471 mlflow.dspy.autolog() 472 assert callback in dspy.settings.callbacks 473 474 mlflow.dspy.autolog(disable=True) 475 assert callback in dspy.settings.callbacks 476 477 478 def test_disable_autolog(): 479 lm = DummyLM([{"output": "test output"}]) 480 mlflow.dspy.autolog() 481 lm("test input") 482 483 assert len(get_traces()) == 1 484 485 mlflow.dspy.autolog(disable=True) 486 487 lm("test input") 488 489 # no additional trace should be created 490 assert len(get_traces()) == 1 491 492 mlflow.dspy.autolog(log_traces=False) 493 494 lm("test input") 495 496 # no additional trace should be created 497 assert len(get_traces()) == 1 498 499 500 @skip_when_testing_trace_sdk 501 def test_autolog_set_retriever_schema(): 502 from mlflow.models.dependencies_schemas import DependenciesSchemasType, _clear_retriever_schema 503 504 mlflow.dspy.autolog() 505 dspy.settings.configure( 506 lm=DummyLM([{"answer": answer, "reasoning": "reason"} for answer in ["4", "6", "8", "10"]]) 507 ) 508 509 with mlflow.start_run(): 510 model_info = mlflow.dspy.log_model(CoT(), name="model") 511 512 # Reset retriever schema 513 _clear_retriever_schema() 514 515 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 516 loaded_model.predict({"question": "What is 2 + 2?"}) 517 518 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 519 assert trace is not None 520 assert trace.info.status == "OK" 521 assert json.loads(trace.info.tags[DependenciesSchemasType.RETRIEVERS.value]) == [ 522 { 523 "name": "retriever", 524 "primary_key": "id", 525 "text_column": "text", 526 "doc_uri": "source", 527 "other_columns": [], 528 } 529 ] 530 531 532 @skip_when_testing_trace_sdk 533 @pytest.mark.parametrize("with_dependencies_schema", [True, False]) 534 def test_dspy_auto_tracing_in_databricks_model_serving(with_dependencies_schema): 535 from mlflow.models.dependencies_schemas import DependenciesSchemasType 536 537 mlflow.dspy.autolog() 538 539 dspy.settings.configure( 540 lm=DummyLM( 541 [ 542 { 543 "answer": "test output", 544 "reasoning": "No more responses", 545 }, 546 ] 547 * 2 548 ) 549 ) 550 551 if with_dependencies_schema: 552 mlflow.models.set_retriever_schema( 553 primary_key="primary-key", 554 text_column="text-column", 555 doc_uri="doc-uri", 556 other_columns=["column1", "column2"], 557 ) 558 559 input_example = "What castle did David Gregory inherit?" 560 561 with mlflow.start_run(): 562 model_info = mlflow.dspy.log_model(RAG(), name="model", input_example=input_example) 563 564 databricks_request_id, response, trace_dict = score_in_model_serving( 565 model_info.model_uri, 566 input_example, 567 ) 568 569 trace = Trace.from_dict(trace_dict) 570 assert trace.info.trace_id.startswith("tr-") 571 assert trace.info.client_request_id == databricks_request_id 572 assert trace.info.status == "OK" 573 574 spans = trace.data.spans 575 assert len(spans) == 8 576 assert [span.name for span in spans] == [ 577 "RAG.forward", 578 "retrieve_context", 579 "DummyRetriever.forward", 580 "ChainOfThought.forward", 581 "Predict.forward", 582 "ChatAdapter.format", 583 "DummyLM.__call__", 584 "ChatAdapter.parse", 585 ] 586 587 if with_dependencies_schema: 588 assert json.loads(trace.info.tags[DependenciesSchemasType.RETRIEVERS.value]) == [ 589 { 590 "name": "retriever", 591 "primary_key": "primary-key", 592 "text_column": "text-column", 593 "doc_uri": "doc-uri", 594 "other_columns": ["column1", "column2"], 595 } 596 ] 597 598 599 @skip_when_testing_trace_sdk 600 @pytest.mark.parametrize("log_compiles", [True, False]) 601 def test_autolog_log_compile(log_compiles): 602 class DummyOptimizer(dspy.teleprompt.Teleprompter): 603 def compile(self, program, kwarg1=None, kwarg2=None): 604 callback = dspy.settings.callbacks[0] 605 assert callback.optimizer_stack_level == 1 606 return program 607 608 mlflow.dspy.autolog(log_compiles=log_compiles) 609 dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}])) 610 611 program = dspy.ChainOfThought("question -> answer") 612 optimizer = DummyOptimizer() 613 614 optimizer.compile(program, kwarg1=1, kwarg2="2") 615 616 assert dspy.settings.callbacks[0].optimizer_stack_level == 0 617 if log_compiles: 618 run = mlflow.last_active_run() 619 assert run is not None 620 assert run.data.params == { 621 "kwarg1": "1", 622 "kwarg2": "2", 623 "lm_params": json.dumps({ 624 "cache": True, 625 "max_tokens": 1000, 626 "model": "dummy", 627 "model_type": "chat", 628 "temperature": 0.0, 629 }), 630 } 631 client = MlflowClient() 632 artifacts = (x.path for x in client.list_artifacts(run.info.run_id)) 633 assert "best_model.json" in artifacts 634 635 # verify that a dummy model output is logged 636 run = client.get_run(run.info.run_id) 637 assert len(run.outputs.model_outputs) == 1 638 assert isinstance(run.outputs.model_outputs[0], LoggedModelOutput) 639 else: 640 assert mlflow.last_active_run() is None 641 642 643 @skip_when_testing_trace_sdk 644 @pytest.mark.parametrize("log_compiles", [True, False]) 645 def test_autolog_log_compile_log_model_output_when_failure(log_compiles): 646 class DummyOptimizer(dspy.teleprompt.Teleprompter): 647 def compile(self, program, kwarg1=None, kwarg2=None): 648 raise Exception("test error") 649 650 mlflow.dspy.autolog(log_compiles=log_compiles) 651 dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}])) 652 653 program = dspy.ChainOfThought("question -> answer") 654 optimizer = DummyOptimizer() 655 656 with pytest.raises(Exception, match="test error"): 657 optimizer.compile(program, kwarg1=1, kwarg2="2") 658 659 if log_compiles: 660 run = mlflow.last_active_run() 661 assert run is not None 662 663 # verify that a dummy model output is logged even when compilation fails 664 client = MlflowClient() 665 run = client.get_run(run.info.run_id) 666 assert len(run.outputs.model_outputs) == 1 667 assert isinstance(run.outputs.model_outputs[0], LoggedModelOutput) 668 else: 669 assert mlflow.last_active_run() is None 670 671 672 @skip_when_testing_trace_sdk 673 def test_autolog_log_compile_disable(): 674 class DummyOptimizer(dspy.teleprompt.Teleprompter): 675 def compile(self, program): 676 return program 677 678 mlflow.dspy.autolog(log_compiles=True) 679 dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}])) 680 681 program = dspy.ChainOfThought("question -> answer") 682 optimizer = DummyOptimizer() 683 684 optimizer.compile(program) 685 686 run = mlflow.last_active_run() 687 assert run is not None 688 689 # verify that run is not created when disabling autologging 690 mlflow.dspy.autolog(disable=True) 691 optimizer.compile(program) 692 client = MlflowClient() 693 runs = client.search_runs(run.info.experiment_id) 694 assert len(runs) == 1 695 696 697 @skip_when_testing_trace_sdk 698 def test_autolog_log_nested_compile(): 699 class NestedOptimizer(dspy.teleprompt.Teleprompter): 700 def compile(self, program): 701 callback = dspy.settings.callbacks[0] 702 assert callback.optimizer_stack_level == 2 703 return program 704 705 class DummyOptimizer(dspy.teleprompt.Teleprompter): 706 def __init__(self): 707 super().__init__() 708 self.nested_optimizer = NestedOptimizer() 709 710 def compile(self, program): 711 self.nested_optimizer.compile(program) 712 callback = dspy.settings.callbacks[0] 713 assert callback.optimizer_stack_level == 1 714 return program 715 716 mlflow.dspy.autolog(log_compiles=True) 717 dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}])) 718 719 program = dspy.ChainOfThought("question -> answer") 720 optimizer = DummyOptimizer() 721 722 optimizer.compile(program) 723 724 assert dspy.settings.callbacks[0].optimizer_stack_level == 0 725 run = mlflow.last_active_run() 726 assert run is not None 727 client = MlflowClient() 728 artifacts = (x.path for x in client.list_artifacts(run.info.run_id)) 729 assert "best_model.json" in artifacts 730 731 732 skip_if_evaluate_callback_unavailable = pytest.mark.skipif( 733 Version(importlib.metadata.version("dspy")) < Version("2.6.12"), 734 reason="evaluate callback is available since 2.6.12", 735 ) 736 737 738 # Evaluate.call starts to return dspy.Prediction since 2.7.0 739 is_2_7_or_newer = Version(importlib.metadata.version("dspy")) >= Version("2.7.0") 740 741 742 @skip_when_testing_trace_sdk 743 @skip_if_evaluate_callback_unavailable 744 @pytest.mark.parametrize("log_evals", [True, False]) 745 @pytest.mark.parametrize("return_outputs", [True, False]) 746 @pytest.mark.parametrize( 747 ("lm", "examples", "expected_result_table"), 748 [ 749 ( 750 DummyLM({ 751 "What is 1 + 1?": {"answer": "2"}, 752 "What is 2 + 2?": {"answer": "1000"}, 753 }), 754 [ 755 Example(question="What is 1 + 1?", answer="2").with_inputs("question"), 756 Example(question="What is 2 + 2?", answer="4").with_inputs("question"), 757 ], 758 { 759 "columns": ["score", "example_question", "example_answer", "pred_answer"], 760 "data": [ 761 [True, "What is 1 + 1?", "2", "2"], 762 [False, "What is 2 + 2?", "4", "1000"], 763 ], 764 }, 765 ), 766 ( 767 DummyLM({ 768 "What is 1 + 1?": {"answer": "2"}, 769 "What is 2 + 2?": {"answer": "1000"}, 770 }), 771 [ 772 Example(question="What is 1 + 1?", answer="2").with_inputs("question"), 773 Example(question="What is 2 + 2?", answer="4", reason="should be 4").with_inputs( 774 "question" 775 ), 776 ], 777 { 778 "columns": [ 779 "score", 780 "example_question", 781 "example_answer", 782 "pred_answer", 783 "example_reason", 784 ], 785 "data": [ 786 [True, "What is 1 + 1?", "2", "2", None], 787 [False, "What is 2 + 2?", "4", "1000", "should be 4"], 788 ], 789 }, 790 ), 791 ], 792 ) 793 def test_autolog_log_evals( 794 tmp_path, log_evals, return_outputs, lm, examples, expected_result_table 795 ): 796 mlflow.dspy.autolog(log_evals=log_evals) 797 798 with dspy.context(lm=lm): 799 program = Predict("question -> answer") 800 if is_2_7_or_newer: 801 evaluator = Evaluate(devset=examples, metric=answer_exact_match) 802 else: 803 # return_outputs arg does not exist after 2.7 804 evaluator = Evaluate( 805 devset=examples, metric=answer_exact_match, return_outputs=return_outputs 806 ) 807 evaluator(program, devset=examples) 808 809 run = mlflow.last_active_run() 810 if log_evals: 811 assert run is not None 812 assert run.data.metrics == {"eval": 50.0} 813 assert run.data.params == { 814 "Predict.signature.fields.0.description": "${question}", 815 "Predict.signature.fields.0.prefix": "Question:", 816 "Predict.signature.fields.1.description": "${answer}", 817 "Predict.signature.fields.1.prefix": "Answer:", 818 "Predict.signature.instructions": "Given the fields `question`, produce the fields `answer`.", # noqa: E501 819 "lm_params": json.dumps({ 820 "cache": True, 821 "max_tokens": 1000, 822 "model": "dummy", 823 "model_type": "chat", 824 "temperature": 0.0, 825 }), 826 } 827 client = MlflowClient() 828 artifacts = (x.path for x in client.list_artifacts(run.info.run_id)) 829 assert "model.json" in artifacts 830 if is_2_7_or_newer: 831 assert "result_table.json" in artifacts 832 client.download_artifacts( 833 run_id=run.info.run_id, path="result_table.json", dst_path=tmp_path 834 ) 835 result_table = json.loads((tmp_path / "result_table.json").read_text()) 836 assert result_table == expected_result_table 837 else: 838 assert run is None 839 840 841 @skip_when_testing_trace_sdk 842 @skip_if_evaluate_callback_unavailable 843 def test_autolog_log_evals_disable_by_caller(): 844 mlflow.dspy.autolog(log_evals=True) 845 examples = [ 846 Example(question="What is 1 + 1?", answer="2").with_inputs("question"), 847 ] 848 evaluator = Evaluate(devset=examples, metric=answer_exact_match) 849 program = Predict("question -> answer") 850 with dspy.context(lm=DummyLM([{"answer": "2"}])): 851 evaluator(program, devset=examples, callback_metadata={"disable_logging": True}) 852 853 assert mlflow.last_active_run() is None 854 855 856 @skip_when_testing_trace_sdk 857 @skip_if_evaluate_callback_unavailable 858 def test_autolog_nested_evals(): 859 lm = DummyLM({ 860 "What is 1 + 1?": {"answer": "2"}, 861 "What is 2 + 2?": {"answer": "4"}, 862 }) 863 dspy.settings.configure(lm=lm) 864 examples = [ 865 Example(question="What is 1 + 1?", answer="2").with_inputs("question"), 866 Example(question="What is 2 + 2?", answer="2").with_inputs("question"), 867 ] 868 program = Predict("question -> answer") 869 evaluator = Evaluate(devset=examples, metric=answer_exact_match) 870 871 mlflow.dspy.autolog(log_evals=True) 872 with mlflow.start_run() as active_run: 873 evaluator(program, devset=examples[:1]) 874 evaluator(program, devset=examples[1:]) 875 876 client = MlflowClient() 877 run = client.get_run(active_run.info.run_id) 878 assert run.data.metrics == {"eval": 0.0} 879 880 artifacts = (x.path for x in client.list_artifacts(run.info.run_id)) 881 assert "model.json" in artifacts 882 883 metric_history = client.get_metric_history(run.info.run_id, "eval") 884 assert [metric.value for metric in metric_history] == [100.0, 0.0] 885 886 child_runs = client.search_runs( 887 run.info.experiment_id, 888 filter_string=f"tags.mlflow.parentRunId = '{run.info.run_id}'", 889 order_by=["attributes.start_time ASC"], 890 ) 891 892 assert len(child_runs) == 0 893 894 895 @skip_when_testing_trace_sdk 896 @skip_if_evaluate_callback_unavailable 897 @pytest.mark.parametrize("call_args", ["args", "kwargs", "mixed"]) 898 def test_autolog_log_traces_from_evals(call_args): 899 mlflow.dspy.autolog(log_evals=True, log_traces_from_eval=True) 900 dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}])) 901 902 class DummyProgram(dspy.Module): 903 def forward(self, question): 904 return dspy.Prediction(answer="2") 905 906 examples = [ 907 Example(question="What is 1 + 1?", answer="2").with_inputs("question"), 908 Example(question="What is 2 + 2?", answer="4").with_inputs("question"), 909 ] 910 911 program = DummyProgram() 912 evaluator = Evaluate(devset=examples, metric=answer_exact_match) 913 914 if call_args == "args": 915 result = evaluator(program, answer_exact_match, examples) 916 elif call_args == "kwargs": 917 result = evaluator(program=program, devset=examples, metric=answer_exact_match) 918 else: 919 result = evaluator(program, answer_exact_match, devset=examples) 920 921 if _DSPY_VERSION >= Version("3.0.0"): 922 from dspy.evaluate.evaluate import EvaluationResult 923 924 assert isinstance(result, EvaluationResult) 925 else: 926 assert result is not None 927 928 traces = get_traces() 929 assert len(traces) == 2 930 assert all(trace.info.status == "OK" for trace in traces) 931 932 actual_values = [] 933 934 assessments = traces[0].info.assessments 935 assert len(assessments) == 1 936 assert isinstance(assessments[0], Feedback) 937 assert assessments[0].name == "answer_exact_match" 938 actual_values.append(assessments[0].value) 939 940 assessments = traces[1].info.assessments 941 assert len(assessments) == 1 942 assert isinstance(assessments[0], Feedback) 943 assert assessments[0].name == "answer_exact_match" 944 actual_values.append(assessments[0].value) 945 946 assert set(actual_values) == {True, False} 947 948 949 @skip_when_testing_trace_sdk 950 @skip_if_evaluate_callback_unavailable 951 def test_autolog_log_traces_from_evals_log_error_assessment(): 952 mlflow.dspy.autolog(log_evals=True, log_traces_from_eval=True) 953 dspy.settings.configure(lm=DummyLM([{"answer": "4", "reasoning": "reason"}])) 954 955 class DummyProgram(dspy.Module): 956 def forward(self, question): 957 return dspy.Prediction(answer="2") 958 959 def error_metric(program, devset): 960 raise Exception("Error") 961 962 examples = [Example(question="What is 1 + 1?", answer="2").with_inputs("question")] 963 964 program = DummyProgram() 965 evaluator = Evaluate(devset=examples, metric=error_metric) 966 evaluator(program, error_metric, examples) 967 968 traces = get_traces() 969 assert len(traces) == 1 970 assert traces[0].info.status == "OK" 971 972 assessments = traces[0].info.assessments 973 assert len(assessments) == 1 974 assert isinstance(assessments[0], Feedback) 975 assert assessments[0].name == "error_metric" 976 assert assessments[0].value is None 977 assert assessments[0].error.error_code == "Exception" 978 assert assessments[0].error.error_message == "Error" 979 assert assessments[0].error.stack_trace is not None 980 981 982 @skip_when_testing_trace_sdk 983 @skip_if_evaluate_callback_unavailable 984 def test_autolog_log_compile_with_evals(): 985 class EvalOptimizer(dspy.teleprompt.Teleprompter): 986 def compile(self, program, eval, trainset, valset): 987 eval(program, devset=valset, callback_metadata={"metric_key": "eval_full"}) 988 eval(program, devset=trainset[:1], callback_metadata={"metric_key": "eval_minibatch"}) 989 eval(program, devset=valset, callback_metadata={"metric_key": "eval_full"}) 990 eval(program, devset=trainset[:1], callback_metadata={"metric_key": "eval_minibatch"}) 991 return program 992 993 dspy.settings.configure( 994 lm=DummyLM({ 995 "What is 1 + 1?": {"answer": "2"}, 996 "What is 2 + 2?": {"answer": "1000"}, 997 }) 998 ) 999 dataset = [ 1000 Example(question="What is 1 + 1?", answer="2").with_inputs("question"), 1001 Example(question="What is 2 + 2?", answer="4").with_inputs("question"), 1002 ] 1003 program = Predict("question -> answer") 1004 evaluator = Evaluate(devset=dataset, metric=answer_exact_match) 1005 optimizer = EvalOptimizer() 1006 1007 mlflow.dspy.autolog(log_compiles=True, log_evals=True) 1008 optimizer.compile(program, evaluator, trainset=dataset, valset=dataset) 1009 1010 # callback state 1011 callback = dspy.settings.callbacks[0] 1012 assert callback.optimizer_stack_level == 0 1013 assert callback._call_id_to_metric_key == {} 1014 assert callback._evaluation_counter == {} 1015 1016 # root run 1017 root_run = mlflow.last_active_run() 1018 assert root_run is not None 1019 client = MlflowClient() 1020 artifacts = (x.path for x in client.list_artifacts(root_run.info.run_id)) 1021 assert "best_model.json" in artifacts 1022 assert "trainset.json" in artifacts 1023 assert "valset.json" in artifacts 1024 assert root_run.data.metrics == { 1025 "eval_full": 50.0, 1026 "eval_minibatch": 100.0, 1027 } 1028 1029 # children runs 1030 child_runs = client.search_runs( 1031 root_run.info.experiment_id, 1032 filter_string=f"tags.mlflow.parentRunId = '{root_run.info.run_id}'", 1033 order_by=["attributes.start_time ASC"], 1034 ) 1035 assert len(child_runs) == 4 1036 1037 for i, run in enumerate(child_runs): 1038 if i % 2 == 0: 1039 assert run.data.metrics == {"eval": 50.0} 1040 else: 1041 assert run.data.metrics == {"eval": 100.0} 1042 artifacts = (x.path for x in client.list_artifacts(run.info.run_id)) 1043 assert "model.json" in artifacts 1044 assert run.data.params == { 1045 "Predict.signature.fields.0.description": "${question}", 1046 "Predict.signature.fields.0.prefix": "Question:", 1047 "Predict.signature.fields.1.description": "${answer}", 1048 "Predict.signature.fields.1.prefix": "Answer:", 1049 "Predict.signature.instructions": "Given the fields `question`, produce the fields `answer`.", # noqa: E501 1050 "lm_params": json.dumps({ 1051 "cache": True, 1052 "max_tokens": 1000, 1053 "model": "dummy", 1054 "model_type": "chat", 1055 "temperature": 0.0, 1056 }), 1057 } 1058 1059 1060 @skip_when_testing_trace_sdk 1061 def test_autolog_link_traces_loaded_model_custom_module(): 1062 mlflow.dspy.autolog() 1063 dspy.settings.configure( 1064 lm=DummyLM([{"answer": "test output", "reasoning": "No more responses"}] * 5) 1065 ) 1066 dspy_model = CoT() 1067 1068 model_infos = [] 1069 for _ in range(5): 1070 with mlflow.start_run(): 1071 model_infos.append(mlflow.dspy.log_model(dspy_model, name="model", pip_requirements=[])) 1072 1073 for model_info in model_infos: 1074 loaded_model = mlflow.dspy.load_model(model_info.model_uri) 1075 loaded_model(model_info.model_id) 1076 1077 traces = get_traces() 1078 assert len(traces) == len(model_infos) 1079 for trace in traces: 1080 model_id = json.loads(trace.data.request)["args"][0] 1081 assert model_id == trace.info.request_metadata[TraceMetadataKey.MODEL_ID] 1082 1083 1084 @skip_when_testing_trace_sdk 1085 def test_autolog_link_traces_loaded_model_custom_module_pyfunc(): 1086 mlflow.dspy.autolog() 1087 dspy.settings.configure( 1088 lm=DummyLM([{"answer": "test output", "reasoning": "No more responses"}] * 5) 1089 ) 1090 dspy_model = CoT() 1091 1092 model_infos = [] 1093 for _ in range(5): 1094 with mlflow.start_run(): 1095 model_infos.append(mlflow.dspy.log_model(dspy_model, name="model", pip_requirements=[])) 1096 1097 for model_info in model_infos: 1098 pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri) 1099 pyfunc_model.predict(model_info.model_id) 1100 1101 traces = get_traces() 1102 assert len(traces) == len(model_infos) 1103 for trace in traces: 1104 model_id = json.loads(trace.data.request)["args"][0] 1105 assert model_id == trace.info.request_metadata[TraceMetadataKey.MODEL_ID] 1106 1107 1108 @skip_when_testing_trace_sdk 1109 def test_autolog_link_traces_active_model(): 1110 model = mlflow.create_external_model(name="test_model") 1111 mlflow.set_active_model(model_id=model.model_id) 1112 mlflow.dspy.autolog() 1113 dspy.settings.configure( 1114 lm=DummyLM([{"answer": "test output", "reasoning": "No more responses"}] * 5) 1115 ) 1116 dspy_model = CoT() 1117 1118 model_infos = [] 1119 for _ in range(5): 1120 with mlflow.start_run(): 1121 model_infos.append(mlflow.dspy.log_model(dspy_model, name="model", pip_requirements=[])) 1122 1123 for model_info in model_infos: 1124 pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri) 1125 pyfunc_model.predict(model_info.model_id) 1126 1127 traces = get_traces() 1128 assert len(traces) == len(model_infos) 1129 for trace in traces: 1130 model_id = json.loads(trace.data.request)["args"][0] 1131 assert model_id != model.model_id 1132 assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == model.model_id 1133 1134 1135 @skip_when_testing_trace_sdk 1136 def test_model_loading_set_active_model_id_without_fetching_logged_model(): 1137 mlflow.dspy.autolog() 1138 dspy.settings.configure( 1139 lm=DummyLM([{"answer": "test output", "reasoning": "No more responses"}]) 1140 ) 1141 dspy_model = CoT() 1142 1143 model_info = mlflow.dspy.log_model(dspy_model, name="model", pip_requirements=[]) 1144 1145 with mock.patch("mlflow.get_logged_model", side_effect=Exception("get_logged_model failed")): 1146 loaded_model = mlflow.dspy.load_model(model_info.model_uri) 1147 loaded_model(model_info.model_id) 1148 1149 traces = get_traces() 1150 assert len(traces) == 1 1151 model_id = json.loads(traces[0].data.request)["args"][0] 1152 assert model_id == traces[0].info.request_metadata[TraceMetadataKey.MODEL_ID] 1153 1154 1155 def test_autolog_databricks_rm_retriever(): 1156 mlflow.dspy.autolog() 1157 1158 dspy.settings.configure(lm=DummyLM([{"output": "test output"}])) 1159 1160 class DatabricksRM(dspy.Retrieve): 1161 def __init__(self, retrieve_uri): 1162 self.retrieve_uri = retrieve_uri 1163 1164 def forward(self, query) -> list[str]: 1165 time.sleep(0.1) 1166 return dspy.Prediction( 1167 docs=["doc1", "doc2"], 1168 doc_ids=["id1", "id2"], 1169 doc_uris=["uri1", "uri2"] if self.retrieve_uri else None, 1170 extra_columns=[{"author": "Jim"}, {"author": "tom"}], 1171 ) 1172 1173 DatabricksRM.__module__ = "dspy.retrieve.databricks_rm" 1174 1175 for retrieve_uri in [False, True]: 1176 retriever = DatabricksRM(retrieve_uri) 1177 result = retriever(query="test query") 1178 assert isinstance(result, dspy.Prediction) 1179 1180 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 1181 assert trace is not None 1182 assert trace.info.status == "OK" 1183 assert trace.info.execution_time_ms > 0 1184 1185 spans = trace.data.spans 1186 assert len(spans) == 1 1187 assert spans[0].name == "DatabricksRM.forward" 1188 assert spans[0].span_type == SpanType.RETRIEVER 1189 assert spans[0].status.status_code == "OK" 1190 assert spans[0].inputs == {"query": "test query"} 1191 1192 if retrieve_uri: 1193 uri1 = "uri1" 1194 uri2 = "uri2" 1195 else: 1196 uri1 = None 1197 uri2 = None 1198 1199 assert spans[0].outputs == [ 1200 { 1201 "page_content": "doc1", 1202 "metadata": {"doc_id": "id1", "doc_uri": uri1, "author": "Jim"}, 1203 "id": "id1", 1204 }, 1205 { 1206 "page_content": "doc2", 1207 "metadata": {"doc_id": "id2", "doc_uri": uri2, "author": "tom"}, 1208 "id": "id2", 1209 }, 1210 ]