test_langchain_autolog.py
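"""Tests for MLflow's LangChain autologging (trace) integration.

The suite covers span capture for chat models, tools, agents, and retrievers;
tracer/callback injection across invoke, ainvoke, batch, abatch, stream, and
astream; token usage and cost attribution; linking traces to logged models;
and tracing inside model serving.
"""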

import json
import random
import time
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from typing import Any
from unittest import mock

import langchain_core
import pytest
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks.base import (
    AsyncCallbackHandler,
    BaseCallbackHandler,
    BaseCallbackManager,
)
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.router import RouterRunnable
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langchain_text_splitters.character import CharacterTextSplitter
from packaging import version

import mlflow
from mlflow.entities.span import SpanType
from mlflow.entities.trace import Trace
from mlflow.entities.trace_status import TraceStatus
from mlflow.tracing.constant import TRACE_SCHEMA_VERSION_KEY, SpanAttributeKey, TraceMetadataKey
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.langchain.conftest import DeterministicDummyEmbeddings
from tests.tracing.conftest import async_logging_enabled
from tests.tracing.helper import (
    get_traces,
    purge_traces,
    score_in_model_serving,
    skip_when_testing_trace_sdk,
)

MODEL_DIR = "model"
# The mock OpenAI endpoint simply echoes the prompt back as the completion,
# so the expected output is the prompt itself.
TEST_CONTENT = "What is MLflow?"

_SIMPLE_MODEL_CODE_PATH = "tests/langchain/sample_code/simple_runnable.py"

IS_LANGCHAIN_v1 = version.parse(langchain_core.__version__).major >= 1


def create_openai_runnable(temperature=0.9):
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is {product}?",
    )
    llm = ChatOpenAI(temperature=temperature, stream_usage=True)
    return prompt | llm | StrOutputParser()


@pytest.fixture
def model_info():
    with mlflow.start_run():
        return mlflow.langchain.log_model(_SIMPLE_MODEL_CODE_PATH, pip_requirements=["mlflow"])


@pytest.fixture
def model_infos():
    model_infos = []
    for _ in range(3):
        with mlflow.start_run():
            info = mlflow.langchain.log_model(_SIMPLE_MODEL_CODE_PATH, pip_requirements=["mlflow"])
            model_infos.append(info)
    return model_infos


def create_retriever(tmp_path):
    # Create the vector db, persist the db to a local fs folder
    loader = TextLoader("tests/langchain/state_of_the_union.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = DeterministicDummyEmbeddings(size=5)
    db = FAISS.from_documents(docs, embeddings)
    persist_dir = str(tmp_path / "faiss_index")
    db.save_local(persist_dir)
    query = "What did the president say about Ketanji Brown Jackson"
    return db.as_retriever(), query


def create_fake_chat_model():
    class FakeChatModel(SimpleChatModel):
        """Fake Chat Model wrapper for testing purposes."""

        def _call(
            self,
            messages: list[BaseMessage],
            stop: list[str] | None = None,
            run_manager: CallbackManagerForLLMRun | None = None,
            **kwargs: Any,
        ) -> str:
            return TEST_CONTENT

        @property
        def _llm_type(self) -> str:
            return "fake chat model"

    return FakeChatModel()


def create_runnable_sequence():
    prompt_with_history_str = """
    Here is a history between you and a human: {chat_history}

    Now, please answer this question: {question}
    """
    prompt_with_history = PromptTemplate(
        input_variables=["chat_history", "question"], template=prompt_with_history_str
    )

    def extract_question(input):
        return input[-1]["content"]

    def extract_history(input):
        return input[:-1]

    chat_model = create_fake_chat_model()
    chain_with_history = (
        {
            "question": itemgetter("messages") | RunnableLambda(extract_question),
            "chat_history": itemgetter("messages") | RunnableLambda(extract_history),
        }
        | prompt_with_history
        | chat_model
        | StrOutputParser()
    )
    input_example = {"messages": [{"role": "user", "content": "Who owns MLflow?"}]}
    return chain_with_history, input_example


def test_autolog_record_exception(async_logging_enabled):
    def always_fail(input):
        raise Exception("Error!")

    model = RunnableLambda(always_fail)

    mlflow.langchain.autolog()

    with pytest.raises(Exception, match="Error!"):
        model.invoke("test")

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "ERROR"
    assert len(trace.data.spans) == 1
    assert trace.data.spans[0].name == "always_fail"


def test_chat_model_autolog():
    mlflow.langchain.autolog()
    model = ChatOpenAI(model="gpt-4o-mini", temperature=0.9)
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What is the weather in San Francisco?"),
        AIMessage(
            content="foo",
            tool_calls=[{"name": "GetWeather", "args": {"location": "San Francisco"}, "id": "123"}],
        ),
        ToolMessage(content="Weather in San Francisco is 70F.", tool_call_id="123"),
    ]
    response = model.invoke(messages)

    traces = get_traces()
    assert len(traces) == 1
    assert len(traces[0].data.spans) == 1

    span = traces[0].data.spans[0]
    assert span.name == "ChatOpenAI"
    assert span.span_type == "CHAT_MODEL"
    _LC_TYPE_TO_ROLE = {"human": "user", "ai": "assistant", "system": "system", "tool": "tool"}
    for msg, expected in zip(span.inputs["messages"], messages, strict=True):
        assert msg["role"] == _LC_TYPE_TO_ROLE[expected.type]
        assert msg["content"] == expected.content
    assert span.outputs["choices"][0]["message"]["content"] == response.content
    assert span.get_attribute("invocation_params")["model"] == "gpt-4o-mini"
    assert span.get_attribute("invocation_params")["temperature"] == 0.9
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "langchain"
    assert span.model_name == "gpt-4o-mini"
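

# NB: The tests below exercise normalization of audio content on chat model
# spans: LangChain-style audio blocks are rewritten into OpenAI-style
# "input_audio" blocks, and raw base64 payloads are replaced with
# "mlflow-attachment://" URIs so the blobs are not inlined in the span.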
@pytest.mark.parametrize(
    ("mime_type", "expected_format"),
    [
        ("audio/wav", "wav"),
        ("audio/mpeg", "mp3"),
    ],
)
def test_chat_model_autolog_audio_input_normalization(mime_type, expected_format):
    audio_b64 = "SGVsbG8="

    class AudioInputModel(BaseChatModel):
        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="heard it"))])

        @property
        def _llm_type(self):
            return "audio-input-model"

    mlflow.langchain.autolog()
    model = AudioInputModel()
    model.invoke([
        HumanMessage(
            content=[
                {"type": "text", "text": "What is this?"},
                {
                    "type": "audio",
                    "source_type": "base64",
                    "data": audio_b64,
                    "mime_type": mime_type,
                },
            ]
        )
    ])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = next(s for s in trace.data.spans if s.span_type == "CHAT_MODEL")

    msgs = span.inputs["messages"]
    audio_block = msgs[0]["content"][1]
    assert audio_block["type"] == "input_audio"
    assert audio_block["input_audio"]["format"] == expected_format
    attachment_uri = audio_block["input_audio"]["data"]
    assert attachment_uri.startswith("mlflow-attachment://")
    expected_mime = "mpeg" if expected_format == "mp3" else expected_format
    assert f"content_type=audio%2F{expected_mime}" in attachment_uri


def test_chat_model_autolog_audio_output_normalization():
    audio_b64 = "SGVsbG8="

    class AudioOutputModel(BaseChatModel):
        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            ai_msg = AIMessage(
                content=[
                    {"type": "text", "text": "Here is audio."},
                    {
                        "type": "audio",
                        "source_type": "base64",
                        "data": audio_b64,
                        "mime_type": "audio/wav",
                    },
                ]
            )
            return ChatResult(generations=[ChatGeneration(message=ai_msg)])

        @property
        def _llm_type(self):
            return "audio-output-model"

    mlflow.langchain.autolog()
    model = AudioOutputModel()
    model.invoke([("human", "Give me audio")])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = next(s for s in trace.data.spans if s.span_type == "CHAT_MODEL")

    audio_block = span.outputs["choices"][0]["message"]["content"][1]
    assert audio_block["type"] == "input_audio"
    assert audio_block["input_audio"]["format"] == "wav"
    attachment_uri = audio_block["input_audio"]["data"]
    assert attachment_uri.startswith("mlflow-attachment://")
    assert "content_type=audio%2Fwav" in attachment_uri


def test_chat_model_autolog_openai_audio_output_with_format():
    audio_b64 = "SGVsbG8="

    class OpenAIAudioModelWithFormat(BaseChatModel):
        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            ai_msg = AIMessage(
                content="",
                additional_kwargs={
                    "audio": {
                        "id": "audio_abc123",
                        "data": audio_b64,
                        "expires_at": 9999999999,
                        "transcript": "Yes, I am.",
                    }
                },
            )
            return ChatResult(generations=[ChatGeneration(message=ai_msg)])

        @property
        def _llm_type(self):
            return "openai-audio-model"

        @property
        def _identifying_params(self):
            return {
                "model": "gpt-4o-audio-preview",
                "audio": {"voice": "alloy", "format": "wav"},
            }

    mlflow.langchain.autolog()
    model = OpenAIAudioModelWithFormat()
    model.invoke([("human", "Are you an AI?")])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = next(s for s in trace.data.spans if s.span_type == "CHAT_MODEL")

    content = span.outputs["choices"][0]["message"]["content"]
    assert isinstance(content, list)
    assert content[0] == {"type": "text", "text": "Yes, I am."}
    assert content[1]["type"] == "input_audio"
    attachment_uri = content[1]["input_audio"]["data"]
    assert attachment_uri.startswith("mlflow-attachment://")
    assert "content_type=audio%2Fwav" in attachment_uri
    assert content[1]["input_audio"]["format"] == "wav"


def test_chat_model_autolog_openai_audio_transcript_fallback():
    class OpenAIAudioModel(BaseChatModel):
        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            ai_msg = AIMessage(
                content="",
                additional_kwargs={
                    "audio": {
                        "id": "audio_abc123",
                        "data": "SGVsbG8=",
                        "expires_at": 9999999999,
                        "transcript": "Yes, I am.",
                    }
                },
            )
            return ChatResult(generations=[ChatGeneration(message=ai_msg)])

        @property
        def _llm_type(self):
            return "openai-audio-model"

    mlflow.langchain.autolog()
    model = OpenAIAudioModel()
    model.invoke([("human", "Are you an AI?")])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = next(s for s in trace.data.spans if s.span_type == "CHAT_MODEL")

    assert span.outputs["choices"][0]["message"]["content"] == "Yes, I am."
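

# NB: Counterpart of the transcript-fallback test above: when the message
# already carries text content, the audio transcript must not override it.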
def test_chat_model_autolog_openai_audio_transcript_no_override():
    class AudioModelWithContent(BaseChatModel):
        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            ai_msg = AIMessage(
                content="I have text content.",
                additional_kwargs={
                    "audio": {
                        "id": "audio_abc123",
                        "data": "SGVsbG8=",
                        "expires_at": 9999999999,
                        "transcript": "Different transcript.",
                    }
                },
            )
            return ChatResult(generations=[ChatGeneration(message=ai_msg)])

        @property
        def _llm_type(self):
            return "audio-model-with-content"

    mlflow.langchain.autolog()
    model = AudioModelWithContent()
    model.invoke([("human", "Say something")])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = next(s for s in trace.data.spans if s.span_type == "CHAT_MODEL")

    assert span.outputs["choices"][0]["message"]["content"] == "I have text content."


def test_chat_model_bind_tool_autolog():
    mlflow.langchain.autolog()

    @tool
    def get_weather(location: str) -> str:
        """Get the weather for a location."""
        return f"Weather in {location} is 70F."

    model = ChatOpenAI(model="gpt-4o-mini", temperature=0.9)
    model_with_tools = model.bind_tools([get_weather])
    model_with_tools.invoke("What is the weather in San Francisco?")

    traces = get_traces()
    assert len(traces) == 1
    assert len(traces[0].data.spans) == 1

    span = traces[0].data.spans[0]
    assert span.name == "ChatOpenAI"
    assert span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the weather for a location.",
                "parameters": {
                    "properties": {
                        "location": {
                            "type": "string",
                        }
                    },
                    "required": ["location"],
                    "type": "object",
                },
            },
        }
    ]
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "langchain"
    assert span.model_name == "gpt-4o-mini"


@pytest.mark.skipif(not IS_LANGCHAIN_v1, reason="create_agent is not supported in langchain v0")
@skip_when_testing_trace_sdk
def test_agent_autolog(async_logging_enabled):
    mlflow.langchain.autolog()

    # Load the agent definition (with OpenAI mock) from the sample script
    from langchain.agents import create_agent

    from tests.langchain.sample_code.openai_agent import FakeOpenAI, add, multiply

    model = create_agent(FakeOpenAI(), [add, multiply], system_prompt="You are a helpful assistant")
    prompt = "What is 2 * 3?"
    expected_output = "The result of 2 * 3 is 6."

    result = model.invoke({"messages": [HumanMessage(content=prompt)]})
    assert result["messages"][-1].content == expected_output

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    assert len(traces[0].data.spans) == 7
    spans = traces[0].data.spans
    assert spans[0].name == "LangGraph"
    assert spans[0].span_type == SpanType.CHAIN
    assert spans[0].inputs["messages"][0]["content"] == prompt
    assert spans[0].outputs["messages"][-1]["content"] == expected_output
    llm_spans = [s for s in spans if s.span_type == SpanType.CHAT_MODEL]
    assert len(llm_spans) == 2
    assert all(s.name == "FakeOpenAI" for s in llm_spans)
    tool_spans = [s for s in traces[0].data.spans if s.span_type == SpanType.TOOL]
    assert len(tool_spans) == 1
    assert tool_spans[0].name == "multiply"
    assert tool_spans[0].inputs["a"] == 2
    assert tool_spans[0].inputs["b"] == 3
    assert tool_spans[0].outputs["content"] == "6"


def test_runnable_sequence_autolog(async_logging_enabled):
    mlflow.langchain.autolog()
    chain, input_example = create_runnable_sequence()
    assert chain.invoke(input_example) == TEST_CONTENT

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    for trace in traces:
        spans = {(s.name, s.span_type) for s in trace.data.spans}
        # Since the chain includes parallel execution, the order of some
        # spans is not deterministic.
        assert spans == {
            ("RunnableSequence", "CHAIN"),
            ("RunnableParallel<question,chat_history>", "CHAIN"),
            ("RunnableSequence", "CHAIN"),
            ("RunnableLambda", "CHAIN"),
            ("extract_question", "CHAIN"),
            ("RunnableSequence", "CHAIN"),
            ("RunnableLambda", "CHAIN"),
            ("extract_history", "CHAIN"),
            ("PromptTemplate", "CHAIN"),
            ("FakeChatModel", "CHAT_MODEL"),
            ("StrOutputParser", "CHAIN"),
        }


def test_retriever_autolog(tmp_path, async_logging_enabled):
    mlflow.langchain.autolog()
    model, query = create_retriever(tmp_path)
    model.invoke(query)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    assert len(spans) == 1
    assert spans[0].name == "VectorStoreRetriever"
    assert spans[0].span_type == "RETRIEVER"
    assert spans[0].inputs == query
    assert spans[0].outputs[0]["metadata"] == {"source": "tests/langchain/state_of_the_union.txt"}


class CustomCallbackHandler(BaseCallbackHandler):
    def __init__(self):
        self.logs = []

    def on_chain_start(
        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
    ) -> None:
        self.logs.append("chain_start")

    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
        self.logs.append("chain_end")


class AsyncCustomCallbackHandler(AsyncCallbackHandler):
    def __init__(self):
        self.logs = []

    async def on_chain_start(
        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
    ) -> None:
        self.logs.append("chain_start")

    async def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
        self.logs.append("chain_end")
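

# Config patterns for the callback-injection tests below. LangChain accepts
# user callbacks either as a plain handler list or wrapped in a
# BaseCallbackManager, so both forms are exercised (plus configs with no
# callbacks at all). Each test then checks that a trace is logged, that the
# original config is not mutated by tracer injection, and that the user's own
# handlers still fire.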
_CONFIG_PATTERNS = [
    # Config with no user callbacks
    RunnableConfig(max_concurrency=1),
    RunnableConfig(callbacks=None),
    # With user callbacks
    RunnableConfig(callbacks=[CustomCallbackHandler()]),
    RunnableConfig(callbacks=BaseCallbackManager([CustomCallbackHandler()])),
]

_ASYNC_CONFIG_PATTERNS = [
    RunnableConfig(callbacks=[AsyncCustomCallbackHandler()]),
    RunnableConfig(callbacks=BaseCallbackManager([AsyncCustomCallbackHandler()])),
]


def _reset_callback_handlers(handlers):
    if handlers:
        for handler in handlers:
            handler.logs = []


def _extract_callback_handlers(config) -> list[BaseCallbackHandler] | None:
    if isinstance(config, list):
        callbacks = []
        for c in config:
            if callbacks_in_c := _extract_callback_handlers(c):
                callbacks.extend(callbacks_in_c)
        return callbacks
    # RunnableConfig is also a dict
    elif isinstance(config, dict) and "callbacks" in config:
        callbacks = config["callbacks"]
        if isinstance(callbacks, BaseCallbackManager):
            return callbacks.handlers
        else:
            return callbacks
    else:
        return None


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize("config", _CONFIG_PATTERNS)
def test_langchain_autolog_callback_injection_in_invoke(invoke_arg, config, async_logging_enabled):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)

    input = {"product": "MLflow"}
    if invoke_arg == "args":
        model.invoke(input, config)
    elif invoke_arg == "kwargs":
        model.invoke(input, config=config)
    elif invoke_arg is None:
        model.invoke(input)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert traces[0].data.spans[0].name == "RunnableSequence"
    assert traces[0].data.spans[0].inputs == input
    assert traces[0].data.spans[0].outputs == [{"role": "user", "content": "What is MLflow?"}]
    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        # NB: LangChain has a bug where the callback is invoked a different number
        # of times depending on whether it is passed as a list or as a callback
        # manager. As a workaround, we only check the content of the events, not
        # the count.
        # https://github.com/langchain-ai/langchain/issues/24642
        assert set(handlers[0].logs) == {"chain_start", "chain_end"}


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize("config", _CONFIG_PATTERNS + _ASYNC_CONFIG_PATTERNS)
@pytest.mark.asyncio
async def test_langchain_autolog_callback_injection_in_ainvoke(
    invoke_arg, config, async_logging_enabled
):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)

    input = {"product": "MLflow"}
    if invoke_arg == "args":
        await model.ainvoke(input, config)
    elif invoke_arg == "kwargs":
        await model.ainvoke(input, config=config)
    elif invoke_arg is None:
        await model.ainvoke(input)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert traces[0].data.spans[0].name == "RunnableSequence"
    assert traces[0].data.spans[0].inputs == input
    assert traces[0].data.spans[0].outputs == [{"role": "user", "content": "What is MLflow?"}]

    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        # NB: LangChain has a bug where the callback is invoked a different number
        # of times depending on whether it is passed as a list or as a callback
        # manager. As a workaround, we only check the content of the events, not
        # the count.
        # https://github.com/langchain-ai/langchain/issues/24642
        assert set(handlers[0].logs) == {"chain_start", "chain_end"}


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs"])
@pytest.mark.parametrize(
    "config",
    _CONFIG_PATTERNS
    # lists of configs are also supported for batch calls
    + [[config, config] for config in _CONFIG_PATTERNS],
)
def test_langchain_autolog_callback_injection_in_batch(invoke_arg, config, async_logging_enabled):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)

    input = {"product": "MLflow"}
    if invoke_arg == "args":
        model.batch([input] * 2, config)
    elif invoke_arg == "kwargs":
        model.batch([input] * 2, config=config)
    elif invoke_arg is None:
        model.batch([input] * 2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 2
    for trace in traces:
        assert trace.info.status == "OK"
        assert trace.data.spans[0].name == "RunnableSequence"
        assert trace.data.spans[0].inputs == input
        assert trace.data.spans[0].outputs == [{"role": "user", "content": "What is MLflow?"}]

    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        for handler in handlers:
            assert set(handler.logs) == {"chain_start", "chain_end"}


@skip_when_testing_trace_sdk
def test_tracing_source_run_in_batch():
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    input = {"product": "MLflow"}
    with mlflow.start_run() as run:
        model.batch([input] * 2)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run.info.run_id


@skip_when_testing_trace_sdk
def test_tracing_source_run_in_pyfunc_model_predict(model_info):
    mlflow.langchain.autolog()

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    with mlflow.start_run() as run:
        pyfunc_model.predict([{"product": "MLflow"}] * 2)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run.info.run_id


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize(
    "config",
    _CONFIG_PATTERNS
    + _ASYNC_CONFIG_PATTERNS
    # lists of configs are also supported for batch calls
    + [[config, config] for config in _CONFIG_PATTERNS + _ASYNC_CONFIG_PATTERNS],
)
@pytest.mark.asyncio
async def test_langchain_autolog_callback_injection_in_abatch(
    invoke_arg, config, async_logging_enabled
):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)

    input = {"product": "MLflow"}
    if invoke_arg == "args":
        await model.abatch([input] * 2, config)
    elif invoke_arg == "kwargs":
        await model.abatch([input] * 2, config=config)
    elif invoke_arg is None:
        await model.abatch([input] * 2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 2
    for trace in traces:
        assert trace.info.status == "OK"
        assert trace.data.spans[0].name == "RunnableSequence"
        assert trace.data.spans[0].inputs == input
        assert trace.data.spans[0].outputs == [{"role": "user", "content": "What is MLflow?"}]

    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        for handler in handlers:
            assert set(handler.logs) == {"chain_start", "chain_end"}


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize("config", _CONFIG_PATTERNS)
def test_langchain_autolog_callback_injection_in_stream(invoke_arg, config, async_logging_enabled):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)

    input = {"product": "MLflow"}
    if invoke_arg == "args":
        list(model.stream(input, config))
    elif invoke_arg == "kwargs":
        list(model.stream(input, config=config))
    elif invoke_arg is None:
        list(model.stream(input))

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert traces[0].data.spans[0].name == "RunnableSequence"
    assert traces[0].data.spans[0].inputs == input
    assert traces[0].data.spans[0].outputs == "Hello world"

    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        assert set(handlers[0].logs) == {"chain_start", "chain_end"}


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize("config", _CONFIG_PATTERNS + _ASYNC_CONFIG_PATTERNS)
@pytest.mark.asyncio
async def test_langchain_autolog_callback_injection_in_astream(
    invoke_arg, config, async_logging_enabled
):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)
    input = {"product": "MLflow"}

    async def invoke_astream(model, config):
        if invoke_arg == "args":
            astream = model.astream(input, config)
        elif invoke_arg == "kwargs":
            astream = model.astream(input, config=config)
        elif invoke_arg is None:
            astream = model.astream(input)

        # Consume the stream
        async for _ in astream:
            pass

    await invoke_astream(model, config)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert traces[0].data.spans[0].name == "RunnableSequence"
    assert traces[0].data.spans[0].inputs == input
    assert traces[0].data.spans[0].outputs == "Hello world"

    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        assert set(handlers[0].logs) == {"chain_start", "chain_end"}


def test_langchain_autolog_produces_expected_traces_with_streaming(tmp_path, async_logging_enabled):
    mlflow.langchain.autolog()
    retriever, _ = create_retriever(tmp_path)
    prompt = ChatPromptTemplate.from_template(
        "Answer the following question based on the context: {context}\nQuestion: {question}"
    )
    chat_model = create_fake_chat_model()
    retrieval_chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough(),
        }
        | prompt
        | chat_model
        | StrOutputParser()
    )
    question = "What is a good name for a company that makes MLflow?"
    list(retrieval_chain.stream(question))
    retrieval_chain.invoke(question)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 2
    stream_trace = traces[0]
    invoke_trace = traces[1]

    assert stream_trace.info.status == invoke_trace.info.status == TraceStatus.OK
    assert stream_trace.data.request == invoke_trace.data.request
    assert stream_trace.data.response == invoke_trace.data.response
    assert len(stream_trace.data.spans) == len(invoke_trace.data.spans)


def test_langchain_autolog_tracing_thread_safe(async_logging_enabled):
    mlflow.langchain.autolog()

    model = create_openai_runnable()

    def _invoke():
        # Add a random sleep to simulate a real LLM prediction
        time.sleep(random.uniform(0.1, 0.5))

        model.invoke({"product": "MLflow"})

    with ThreadPoolExecutor(max_workers=8, thread_name_prefix="test-langchain-autolog") as executor:
        futures = [executor.submit(_invoke) for _ in range(30)]
        _ = [f.result() for f in futures]

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 30
    for trace in traces:
        assert trace.info.status == "OK"
        assert len(trace.data.spans) == 4
        assert trace.data.spans[0].name == "RunnableSequence"


@pytest.mark.asyncio
async def test_langchain_autolog_token_usage(mock_litellm_cost):
    mlflow.langchain.autolog()

    model = create_openai_runnable()

    def _validate_token_counts(trace):
        actual = trace.info.token_usage
        assert actual == {"input_tokens": 9, "output_tokens": 12, "total_tokens": 21}

    def _validate_model_name(trace):
        # Find the ChatOpenAI span
        chat_model_span = next(s for s in trace.data.spans if s.name == "ChatOpenAI")
        assert chat_model_span.model_name == "gpt-3.5-turbo"

    def _validate_cost(trace):
        if IS_TRACING_SDK_ONLY:
            return
        # Find the ChatOpenAI span
        chat_model_span = next(s for s in trace.data.spans if s.name == "ChatOpenAI")
        assert chat_model_span.llm_cost == {
            "input_cost": 9.0,
            "output_cost": 24.0,
            "total_cost": 33.0,
        }

    # Normal invoke
    model.invoke({"product": "MLflow"})
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    _validate_token_counts(trace)
    _validate_model_name(trace)
    _validate_cost(trace)

    # Invoke with streaming
    list(model.stream({"product": "MLflow"}))
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    _validate_token_counts(trace)
    _validate_model_name(trace)
    _validate_cost(trace)

    # Async invoke
    await model.ainvoke({"product": "MLflow"})
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    _validate_token_counts(trace)
    _validate_model_name(trace)
    _validate_cost(trace)

    # When both OpenAI and LangChain autologging are enabled,
    # no duplicated token usage should be logged
    mlflow.openai.autolog()

    model.invoke({"product": "MLflow"})
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    _validate_token_counts(trace)
    _validate_model_name(trace)
    _validate_cost(trace)


# NB: log_traces=None exercises the default path: autolog() with no argument
# should behave like log_traces=True, so traces are expected unless log_traces
# is explicitly False.
@pytest.mark.parametrize("log_traces", [True, False, None])
def test_langchain_tracer_injection_for_arbitrary_runnables(log_traces, async_logging_enabled):
    should_log_traces = log_traces is not False

    if log_traces is not None:
        mlflow.langchain.autolog(log_traces=log_traces)
    else:
        mlflow.langchain.autolog()

    add = RunnableLambda(func=lambda x: x + 1)
    square = RunnableLambda(func=lambda x: x**2)
    model = RouterRunnable(runnables={"add": add, "square": square})

    model.invoke({"key": "square", "input": 3})

    if async_logging_enabled and should_log_traces:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    if should_log_traces:
        assert len(traces) == 1
        assert traces[0].data.spans[0].span_type == "CHAIN"
    else:
        assert len(traces) == 0


@skip_when_testing_trace_sdk
@pytest.mark.skip(reason="This test is not thread safe, please run locally")
def test_set_retriever_schema_work_for_langchain_model(model_info):
    from mlflow.models.dependencies_schemas import DependenciesSchemasType, set_retriever_schema

    set_retriever_schema(
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )

    mlflow.langchain.autolog()

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    pyfunc_model.predict("MLflow")

    traces = get_traces()
    assert len(traces) == 1
    assert DependenciesSchemasType.RETRIEVERS.value in traces[0].info.tags

    purge_traces()

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    list(pyfunc_model.predict_stream("MLflow"))

    traces = get_traces()
    assert len(traces) == 1
    assert DependenciesSchemasType.RETRIEVERS.value in traces[0].info.tags


def test_langchain_auto_tracing_work_when_langchain_parent_package_not_installed(
    async_logging_enabled,
):
    original_import = __import__

    def _mock_import(name, *args):
        # Allow langchain.globals and its dependencies for langchain-core 0.3.76 compatibility
        allowed_langchain_modules = {
            "langchain.globals",
            "langchain._api",
            "langchain._api.interactive_env",
        }
        if name.startswith("langchain.") and name not in allowed_langchain_modules:
            raise ImportError("No module named 'langchain'")
        return original_import(name, *args)

    with mock.patch("builtins.__import__", side_effect=_mock_import):
        mlflow.langchain.autolog()

        chain, input_example = create_runnable_sequence()
        assert chain.invoke(input_example) == TEST_CONTENT
        assert chain.invoke(input_example) == TEST_CONTENT

        if async_logging_enabled:
            mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 2
    assert all(len(trace.data.spans) == 11 for trace in traces)
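

# NB: The serving tests below score the logged model in a simulated model
# serving environment (score_in_model_serving) and validate the trace payload
# returned from serving alongside the predictions.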
@skip_when_testing_trace_sdk
def test_langchain_auto_tracing_in_serving_runnable(model_info):
    mlflow.langchain.autolog()

    expected_output = '[{"role": "user", "content": "What is MLflow?"}]'
    databricks_request_id, predictions, trace = score_in_model_serving(
        model_info.model_uri,
        [{"product": "MLflow"}],
    )

    assert predictions == [expected_output]
    trace = Trace.from_dict(trace)
    assert trace.info.trace_id.startswith("tr-")
    assert trace.info.client_request_id == databricks_request_id
    assert trace.info.request_metadata[TRACE_SCHEMA_VERSION_KEY] == "3"
    spans = trace.data.spans
    assert len(spans) == 4

    root_span = spans[0]
    assert root_span.start_time_ns // 1_000_000 == trace.info.timestamp_ms
    # There might be a slight difference when we truncate nanoseconds to milliseconds
    assert (
        root_span.end_time_ns // 1_000_000
        - (trace.info.timestamp_ms + trace.info.execution_time_ms)
    ) <= 1
    assert root_span.inputs == {"product": "MLflow"}
    assert root_span.outputs == expected_output
    assert root_span.span_type == "CHAIN"

    root_span_id = root_span.span_id
    child_span = spans[2]
    assert child_span.parent_id == root_span_id
    assert child_span.inputs["messages"][0]["content"] == "What is MLflow?"
    assert child_span.outputs["choices"][0]["message"]["content"] == expected_output
    assert child_span.span_type == "CHAT_MODEL"


@pytest.mark.skipif(not IS_LANGCHAIN_v1, reason="create_agent is not supported in langchain v0")
@skip_when_testing_trace_sdk
def test_langchain_auto_tracing_in_serving_agent():
    mlflow.langchain.autolog()

    input_example = {"input": "What is 2 * 3?"}

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            "tests/langchain/sample_code/openai_agent.py",
            name="langchain_model",
            input_example=input_example,
        )

    databricks_request_id, response, trace_dict = score_in_model_serving(
        model_info.model_uri,
        input_example,
    )

    trace = Trace.from_dict(trace_dict)
    assert trace.info.trace_id.startswith("tr-")
    assert trace.info.client_request_id == databricks_request_id
    assert trace.info.status == "OK"

    spans = trace.data.spans
    assert len(spans) == 7

    root_span = spans[0]
    assert root_span.name == "LangGraph"
    assert root_span.span_type == SpanType.CHAIN
    assert root_span.inputs["input"] == "What is 2 * 3?"
    assert root_span.outputs["messages"][-1]["content"] == "The result of 2 * 3 is 6."
    assert root_span.start_time_ns // 1_000_000 == trace.info.timestamp_ms
    assert (
        root_span.end_time_ns // 1_000_000
        - (trace.info.timestamp_ms + trace.info.execution_time_ms)
    ) <= 1


def test_langchain_tracing_multi_threads():
    mlflow.langchain.autolog()

    temperatures = [(t + 1) / 10 for t in range(4)]
    models = [create_openai_runnable(temperature=t) for t in temperatures]

    with ThreadPoolExecutor(
        max_workers=len(temperatures), thread_name_prefix="test-langchain-concurrent"
    ) as executor:
        futures = [executor.submit(models[i].invoke, {"product": "MLflow"}) for i in range(4)]
        for f in futures:
            f.result()

    traces = get_traces()
    assert len(traces) == 4
    assert (
        sorted(
            trace.data.spans[2].get_attribute("invocation_params")["temperature"]
            for trace in traces
        )
        == temperatures
    )
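

# NB: The "{temperature}_{model_id}" input below encodes the expected model ID
# into the trace request itself, so each trace can be matched back to the
# model that produced it without relying on trace ordering.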
@skip_when_testing_trace_sdk
@pytest.mark.parametrize("func", ["invoke", "batch", "stream"])
def test_autolog_link_traces_to_loaded_model(model_infos, func):
    mlflow.langchain.autolog()

    for model_info in model_infos:
        loaded_model = mlflow.langchain.load_model(model_info.model_uri)
        msg = {"product": f"{loaded_model.steps[1].temperature}_{model_info.model_id}"}
        if func == "invoke":
            loaded_model.invoke(msg)
        elif func == "batch":
            loaded_model.batch([msg])
        elif func == "stream":
            list(loaded_model.stream(msg))

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        temp = trace.data.spans[2].get_attribute("invocation_params")["temperature"]
        logged_temp, logged_model_id = json.loads(trace.data.request)["product"].split(
            "_", maxsplit=1
        )
        assert logged_model_id is not None
        assert str(temp) == logged_temp
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == logged_model_id


@skip_when_testing_trace_sdk
@pytest.mark.parametrize("func", ["ainvoke", "abatch", "astream"])
@pytest.mark.asyncio
async def test_autolog_link_traces_to_loaded_model_async(model_infos, func):
    mlflow.langchain.autolog()

    for model_info in model_infos:
        loaded_model = mlflow.langchain.load_model(model_info.model_uri)
        msg = {"product": f"{loaded_model.steps[1].temperature}_{model_info.model_id}"}
        if func == "ainvoke":
            await loaded_model.ainvoke(msg)
        elif func == "abatch":
            await loaded_model.abatch([msg])
        elif func == "astream":
            async for _ in loaded_model.astream(msg):
                pass

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        temp = trace.data.spans[2].get_attribute("invocation_params")["temperature"]
        logged_temp, logged_model_id = json.loads(trace.data.request)["product"].split(
            "_", maxsplit=1
        )
        assert logged_model_id is not None
        assert str(temp) == logged_temp
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == logged_model_id


@skip_when_testing_trace_sdk
def test_autolog_link_traces_to_loaded_model_pyfunc(model_infos):
    mlflow.langchain.autolog()

    for model_info in model_infos:
        loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
        loaded_model.predict({"product": model_info.model_id})

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        logged_model_id = json.loads(trace.data.request)["product"]
        assert logged_model_id is not None
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == logged_model_id


@skip_when_testing_trace_sdk
def test_autolog_link_traces_to_active_model(model_infos):
    model = mlflow.create_external_model(name="test_model")
    mlflow.set_active_model(model_id=model.model_id)
    mlflow.langchain.autolog()

    for model_info in model_infos:
        loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
        loaded_model.predict({"product": model_info.model_id})

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        logged_model_id = json.loads(trace.data.request)["product"]
        assert logged_model_id is not None
        # The active model takes precedence over the loaded model
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == model.model_id
        assert model.model_id != logged_model_id


@skip_when_testing_trace_sdk
def test_model_loading_set_active_model_id_without_fetching_logged_model(model_info):
    mlflow.langchain.autolog()

    with mock.patch("mlflow.get_logged_model", side_effect=Exception("get_logged_model failed")):
        loaded_model = mlflow.langchain.load_model(model_info.model_uri)
        loaded_model.invoke({"product": "MLflow"})

    traces = get_traces()
    assert len(traces) == 1
    model_id = traces[0].info.request_metadata[TraceMetadataKey.MODEL_ID]
    assert model_id == model_info.model_id


@skip_when_testing_trace_sdk
@pytest.mark.parametrize("log_traces", [True, False])
def test_langchain_tracing_evaluate(log_traces):
    from mlflow.genai import scorer

    if log_traces:
        mlflow.langchain.autolog()
        mlflow.openai.autolog()  # Our chain contains an OpenAI call as well

    chain = create_openai_runnable()

    data = [
        {
            "inputs": {"product": "MLflow"},
            "expectations": {"expected_response": "MLflow is an open-source platform."},
        },
        {
            "inputs": {"product": "Spark"},
            "expectations": {"expected_response": "Spark is a unified analytics engine."},
        },
    ]

    def predict_fn(product: str) -> str:
        return chain.invoke({"product": product})

    @scorer
    def exact_match(outputs: str, expectations: dict[str, str]) -> bool:
        return outputs == expectations["expected_response"]

    result = mlflow.genai.evaluate(
        predict_fn=predict_fn,
        data=data,
        scorers=[exact_match],
    )
    assert result.metrics["exact_match/mean"] == 0.0
    assert result.result_df is not None

    # Tracing should be enabled automatically during evaluation
    assert len(get_traces()) == 2
    for trace in get_traces():
        assert len(trace.data.spans) == 5
        assert trace.data.spans[0].name == "RunnableSequence"
        assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == result.run_id
        assert len(trace.info.assessments) == 2
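

# NB: run_tracer_inline=True makes the autolog tracer run inline with the
# chain rather than in a detached context, so the LangChain spans nest under
# spans created manually via @mlflow.trace (see the parent/child assertions
# below).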
@pytest.mark.asyncio
async def test_autolog_run_tracer_inline_with_manual_traces_async():
    mlflow.langchain.autolog(run_tracer_inline=True)

    prompt = PromptTemplate(
        input_variables=["color"],
        template="What is the complementary color of {color}?",
    )
    llm = ChatOpenAI()

    @mlflow.trace
    def manual_transform(s: str):
        return s.replace("red", "blue")

    chain = RunnableLambda(manual_transform) | prompt | llm | StrOutputParser()

    @mlflow.trace(name="parent")
    async def run(message):
        return await chain.ainvoke(message)

    response = await run("red")
    expected_response = '[{"role": "user", "content": "What is the complementary color of blue?"}]'
    assert response == expected_response

    traces = get_traces()
    assert len(traces) == 1

    trace = traces[0]
    spans = trace.data.spans
    assert spans[0].name == "parent"
    assert spans[1].name == "RunnableSequence"
    assert spans[1].parent_id == spans[0].span_id
    assert spans[2].name == "manual_transform"
    assert spans[2].parent_id == spans[1].span_id
    # Find the ChatOpenAI span and verify it carries the model name
    chat_model_span = next(s for s in spans if s.name == "ChatOpenAI")
    assert chat_model_span.model_name == "gpt-3.5-turbo"