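"""Tests for MLflow LangChain autologging (mlflow.langchain.autolog).

Covers trace generation for chat models, agents, retrievers, and runnable
sequences; callback injection across (a)invoke/(a)batch/(a)stream; token
usage and cost capture; model-ID linking; and model-serving scenarios.
"""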
import json
import random
import time
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from typing import Any
from unittest import mock

import langchain_core
import pytest
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks.base import (
    AsyncCallbackHandler,
    BaseCallbackHandler,
    BaseCallbackManager,
)
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.router import RouterRunnable
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langchain_text_splitters.character import CharacterTextSplitter
from packaging import version

import mlflow
from mlflow.entities.span import SpanType
from mlflow.entities.trace import Trace
from mlflow.entities.trace_status import TraceStatus
from mlflow.tracing.constant import TRACE_SCHEMA_VERSION_KEY, SpanAttributeKey, TraceMetadataKey
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.langchain.conftest import DeterministicDummyEmbeddings
from tests.tracing.conftest import async_logging_enabled
from tests.tracing.helper import (
    get_traces,
    purge_traces,
    score_in_model_serving,
    skip_when_testing_trace_sdk,
)

MODEL_DIR = "model"
# The mock OpenAI endpoint simply echoes the prompt back as the completion,
# so the expected output is the prompt itself.
TEST_CONTENT = "What is MLflow?"

_SIMPLE_MODEL_CODE_PATH = "tests/langchain/sample_code/simple_runnable.py"

IS_LANGCHAIN_v1 = version.parse(langchain_core.__version__).major >= 1


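# Helper: build a minimal prompt -> ChatOpenAI -> StrOutputParser chain. The
# suite runs against the mock OpenAI endpoint described above (it echoes the
# prompt back), so `temperature` mainly matters for asserting invocation
# params on the recorded spans.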
def create_openai_runnable(temperature=0.9):
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is {product}?",
    )
    llm = ChatOpenAI(temperature=temperature, stream_usage=True)
    return prompt | llm | StrOutputParser()


@pytest.fixture
def model_info():
    with mlflow.start_run():
        return mlflow.langchain.log_model(_SIMPLE_MODEL_CODE_PATH, pip_requirements=["mlflow"])


@pytest.fixture
def model_infos():
    model_infos = []
    for _ in range(3):
        with mlflow.start_run():
            info = mlflow.langchain.log_model(_SIMPLE_MODEL_CODE_PATH, pip_requirements=["mlflow"])
            model_infos.append(info)
    return model_infos


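# Helper: build a FAISS vector store over the state-of-the-union text using
# deterministic dummy embeddings, persist it under tmp_path, and return a
# retriever plus a sample query.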
def create_retriever(tmp_path):
    # Create the vector DB and persist it to a local filesystem folder
    loader = TextLoader("tests/langchain/state_of_the_union.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = DeterministicDummyEmbeddings(size=5)
    db = FAISS.from_documents(docs, embeddings)
    persist_dir = str(tmp_path / "faiss_index")
    db.save_local(persist_dir)
    query = "What did the president say about Ketanji Brown Jackson"
    return db.as_retriever(), query


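# Helper: a SimpleChatModel stub that always returns TEST_CONTENT, keeping
# chain tests deterministic without hitting any endpoint.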
def create_fake_chat_model():
    class FakeChatModel(SimpleChatModel):
        """Fake Chat Model wrapper for testing purposes."""

        def _call(
            self,
            messages: list[BaseMessage],
            stop: list[str] | None = None,
            run_manager: CallbackManagerForLLMRun | None = None,
            **kwargs: Any,
        ) -> str:
            return TEST_CONTENT

        @property
        def _llm_type(self) -> str:
            return "fake chat model"

    return FakeChatModel()


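# Helper: a chain with a parallel question/history extraction step feeding a
# prompt and the fake chat model; returns the chain and a sample input.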
def create_runnable_sequence():
    prompt_with_history_str = """
    Here is a history between you and a human: {chat_history}

    Now, please answer this question: {question}
    """
    prompt_with_history = PromptTemplate(
        input_variables=["chat_history", "question"], template=prompt_with_history_str
    )

    def extract_question(input):
        return input[-1]["content"]

    def extract_history(input):
        return input[:-1]

    chat_model = create_fake_chat_model()
    chain_with_history = (
        {
            "question": itemgetter("messages") | RunnableLambda(extract_question),
            "chat_history": itemgetter("messages") | RunnableLambda(extract_history),
        }
        | prompt_with_history
        | chat_model
        | StrOutputParser()
    )
    input_example = {"messages": [{"role": "user", "content": "Who owns MLflow?"}]}
    return chain_with_history, input_example


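# A failing runnable should still produce a single trace, marked ERROR, with
# the failing span recorded.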
def test_autolog_record_exception(async_logging_enabled):
    def always_fail(input):
        raise Exception("Error!")

    model = RunnableLambda(always_fail)

    mlflow.langchain.autolog()

    with pytest.raises(Exception, match="Error!"):
        model.invoke("test")

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.status == "ERROR"
    assert len(trace.data.spans) == 1
    assert trace.data.spans[0].name == "always_fail"


def test_chat_model_autolog():
    mlflow.langchain.autolog()
    model = ChatOpenAI(model="gpt-4o-mini", temperature=0.9)
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="What is the weather in San Francisco?"),
        AIMessage(
            content="foo",
            tool_calls=[{"name": "GetWeather", "args": {"location": "San Francisco"}, "id": "123"}],
        ),
        ToolMessage(content="Weather in San Francisco is 70F.", tool_call_id="123"),
    ]
    response = model.invoke(messages)

    traces = get_traces()
    assert len(traces) == 1
    assert len(traces[0].data.spans) == 1

    span = traces[0].data.spans[0]
    assert span.name == "ChatOpenAI"
    assert span.span_type == "CHAT_MODEL"
    _LC_TYPE_TO_ROLE = {"human": "user", "ai": "assistant", "system": "system", "tool": "tool"}
    for msg, expected in zip(span.inputs["messages"], messages, strict=True):
        assert msg["role"] == _LC_TYPE_TO_ROLE[expected.type]
        assert msg["content"] == expected.content
    assert span.outputs["choices"][0]["message"]["content"] == response.content
    assert span.get_attribute("invocation_params")["model"] == "gpt-4o-mini"
    assert span.get_attribute("invocation_params")["temperature"] == 0.9
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "langchain"
    assert span.model_name == "gpt-4o-mini"


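# The next four tests cover audio content normalization: inline base64 audio
# in inputs/outputs should be rewritten to OpenAI-style "input_audio" blocks
# whose data is offloaded to an mlflow-attachment:// URI, with the transcript
# used as a fallback when an OpenAI-style `audio` payload has no content.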
@pytest.mark.parametrize(
    ("mime_type", "expected_format"),
    [
        ("audio/wav", "wav"),
        ("audio/mpeg", "mp3"),
    ],
)
def test_chat_model_autolog_audio_input_normalization(mime_type, expected_format):
    audio_b64 = "SGVsbG8="

    class AudioInputModel(BaseChatModel):
        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            return ChatResult(generations=[ChatGeneration(message=AIMessage(content="heard it"))])

        @property
        def _llm_type(self):
            return "audio-input-model"

    mlflow.langchain.autolog()
    model = AudioInputModel()
    model.invoke([
        HumanMessage(
            content=[
                {"type": "text", "text": "What is this?"},
                {
                    "type": "audio",
                    "source_type": "base64",
                    "data": audio_b64,
                    "mime_type": mime_type,
                },
            ]
        )
    ])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = next(s for s in trace.data.spans if s.span_type == "CHAT_MODEL")

    msgs = span.inputs["messages"]
    audio_block = msgs[0]["content"][1]
    assert audio_block["type"] == "input_audio"
    assert audio_block["input_audio"]["format"] == expected_format
    attachment_uri = audio_block["input_audio"]["data"]
    assert attachment_uri.startswith("mlflow-attachment://")
    expected_mime = "mpeg" if expected_format == "mp3" else expected_format
    assert f"content_type=audio%2F{expected_mime}" in attachment_uri


def test_chat_model_autolog_audio_output_normalization():
    audio_b64 = "SGVsbG8="

    class AudioOutputModel(BaseChatModel):
        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            ai_msg = AIMessage(
                content=[
                    {"type": "text", "text": "Here is audio."},
                    {
                        "type": "audio",
                        "source_type": "base64",
                        "data": audio_b64,
                        "mime_type": "audio/wav",
                    },
                ]
            )
            return ChatResult(generations=[ChatGeneration(message=ai_msg)])

        @property
        def _llm_type(self):
            return "audio-output-model"

    mlflow.langchain.autolog()
    model = AudioOutputModel()
    model.invoke([("human", "Give me audio")])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = next(s for s in trace.data.spans if s.span_type == "CHAT_MODEL")

    audio_block = span.outputs["choices"][0]["message"]["content"][1]
    assert audio_block["type"] == "input_audio"
    assert audio_block["input_audio"]["format"] == "wav"
    attachment_uri = audio_block["input_audio"]["data"]
    assert attachment_uri.startswith("mlflow-attachment://")
    assert "content_type=audio%2Fwav" in attachment_uri


def test_chat_model_autolog_openai_audio_output_with_format():
    audio_b64 = "SGVsbG8="

    class OpenAIAudioModelWithFormat(BaseChatModel):
        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            ai_msg = AIMessage(
                content="",
                additional_kwargs={
                    "audio": {
                        "id": "audio_abc123",
                        "data": audio_b64,
                        "expires_at": 9999999999,
                        "transcript": "Yes, I am.",
                    }
                },
            )
            return ChatResult(generations=[ChatGeneration(message=ai_msg)])

        @property
        def _llm_type(self):
            return "openai-audio-model"

        @property
        def _identifying_params(self):
            return {
                "model": "gpt-4o-audio-preview",
                "audio": {"voice": "alloy", "format": "wav"},
            }

    mlflow.langchain.autolog()
    model = OpenAIAudioModelWithFormat()
    model.invoke([("human", "Are you an AI?")])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = next(s for s in trace.data.spans if s.span_type == "CHAT_MODEL")

    content = span.outputs["choices"][0]["message"]["content"]
    assert isinstance(content, list)
    assert content[0] == {"type": "text", "text": "Yes, I am."}
    assert content[1]["type"] == "input_audio"
    attachment_uri = content[1]["input_audio"]["data"]
    assert attachment_uri.startswith("mlflow-attachment://")
    assert "content_type=audio%2Fwav" in attachment_uri
    assert content[1]["input_audio"]["format"] == "wav"


def test_chat_model_autolog_openai_audio_transcript_fallback():
    class OpenAIAudioModel(BaseChatModel):
        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            ai_msg = AIMessage(
                content="",
                additional_kwargs={
                    "audio": {
                        "id": "audio_abc123",
                        "data": "SGVsbG8=",
                        "expires_at": 9999999999,
                        "transcript": "Yes, I am.",
                    }
                },
            )
            return ChatResult(generations=[ChatGeneration(message=ai_msg)])

        @property
        def _llm_type(self):
            return "openai-audio-model"

    mlflow.langchain.autolog()
    model = OpenAIAudioModel()
    model.invoke([("human", "Are you an AI?")])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = next(s for s in trace.data.spans if s.span_type == "CHAT_MODEL")

    assert span.outputs["choices"][0]["message"]["content"] == "Yes, I am."


def test_chat_model_autolog_openai_audio_transcript_no_override():
    class AudioModelWithContent(BaseChatModel):
        def _generate(self, messages, stop=None, run_manager=None, **kwargs):
            ai_msg = AIMessage(
                content="I have text content.",
                additional_kwargs={
                    "audio": {
                        "id": "audio_abc123",
                        "data": "SGVsbG8=",
                        "expires_at": 9999999999,
                        "transcript": "Different transcript.",
                    }
                },
            )
            return ChatResult(generations=[ChatGeneration(message=ai_msg)])

        @property
        def _llm_type(self):
            return "audio-model-with-content"

    mlflow.langchain.autolog()
    model = AudioModelWithContent()
    model.invoke([("human", "Say something")])

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = next(s for s in trace.data.spans if s.span_type == "CHAT_MODEL")

    assert span.outputs["choices"][0]["message"]["content"] == "I have text content."


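# Binding tools to a chat model should surface the tool schema on the span's
# CHAT_TOOLS attribute in OpenAI function-calling format.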
def test_chat_model_bind_tool_autolog():
    mlflow.langchain.autolog()

    @tool
    def get_weather(location: str) -> str:
        """Get the weather for a location."""
        return f"Weather in {location} is 70F."

    model = ChatOpenAI(model="gpt-4o-mini", temperature=0.9)
    model_with_tools = model.bind_tools([get_weather])
    model_with_tools.invoke("What is the weather in San Francisco?")

    traces = get_traces()
    assert len(traces) == 1
    assert len(traces[0].data.spans) == 1

    span = traces[0].data.spans[0]
    assert span.name == "ChatOpenAI"
    assert span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the weather for a location.",
                "parameters": {
                    "properties": {
                        "location": {
                            "type": "string",
                        }
                    },
                    "required": ["location"],
                    "type": "object",
                },
            },
        }
    ]
    assert span.get_attribute(SpanAttributeKey.MESSAGE_FORMAT) == "langchain"
    assert span.model_name == "gpt-4o-mini"


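# An agent built with langchain v1's create_agent executes as a LangGraph; the
# trace should contain the root CHAIN span, two CHAT_MODEL calls (tool call
# plus final answer), and one TOOL span for the multiply invocation.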
@pytest.mark.skipif(not IS_LANGCHAIN_v1, reason="create_agent is not supported in langchain v0")
@skip_when_testing_trace_sdk
def test_agent_autolog(async_logging_enabled):
    mlflow.langchain.autolog()

    # Load the agent definition (with OpenAI mock) from the sample script
    from langchain.agents import create_agent

    from tests.langchain.sample_code.openai_agent import FakeOpenAI, add, multiply

    model = create_agent(FakeOpenAI(), [add, multiply], system_prompt="You are a helpful assistant")
    prompt = "What is 2 * 3?"
    expected_output = "The result of 2 * 3 is 6."

    result = model.invoke({"messages": [HumanMessage(content=prompt)]})
    assert result["messages"][-1].content == expected_output

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    assert len(traces[0].data.spans) == 7
    spans = traces[0].data.spans
    assert spans[0].name == "LangGraph"
    assert spans[0].span_type == SpanType.CHAIN
    assert spans[0].inputs["messages"][0]["content"] == prompt
    assert spans[0].outputs["messages"][-1]["content"] == expected_output
    llm_spans = [s for s in spans if s.span_type == SpanType.CHAT_MODEL]
    assert len(llm_spans) == 2
    assert all(s.name == "FakeOpenAI" for s in llm_spans)
    tool_spans = [s for s in traces[0].data.spans if s.span_type == SpanType.TOOL]
    assert len(tool_spans) == 1
    assert tool_spans[0].name == "multiply"
    assert tool_spans[0].inputs["a"] == 2
    assert tool_spans[0].inputs["b"] == 3
    assert tool_spans[0].outputs["content"] == "6"


def test_runnable_sequence_autolog(async_logging_enabled):
    mlflow.langchain.autolog()
    chain, input_example = create_runnable_sequence()
    assert chain.invoke(input_example) == TEST_CONTENT

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    for trace in traces:
        spans = {(s.name, s.span_type) for s in trace.data.spans}
        # Since the chain includes parallel execution, the order of some
        # spans is not deterministic.
        assert spans == {
            ("RunnableSequence", "CHAIN"),
            ("RunnableParallel<question,chat_history>", "CHAIN"),
            ("RunnableSequence", "CHAIN"),
            ("RunnableLambda", "CHAIN"),
            ("extract_question", "CHAIN"),
            ("RunnableSequence", "CHAIN"),
            ("RunnableLambda", "CHAIN"),
            ("extract_history", "CHAIN"),
            ("PromptTemplate", "CHAIN"),
            ("FakeChatModel", "CHAT_MODEL"),
            ("StrOutputParser", "CHAIN"),
        }


def test_retriever_autolog(tmp_path, async_logging_enabled):
    mlflow.langchain.autolog()
    model, query = create_retriever(tmp_path)
    model.invoke(query)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    spans = traces[0].data.spans
    assert len(spans) == 1
    assert spans[0].name == "VectorStoreRetriever"
    assert spans[0].span_type == "RETRIEVER"
    assert spans[0].inputs == query
    assert spans[0].outputs[0]["metadata"] == {"source": "tests/langchain/state_of_the_union.txt"}


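# Callback handlers used to verify that autologging's tracer injection does
# not drop or mutate user-supplied callbacks. Each handler records the chain
# lifecycle events it receives.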
class CustomCallbackHandler(BaseCallbackHandler):
    def __init__(self):
        self.logs = []

    def on_chain_start(
        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
    ) -> None:
        self.logs.append("chain_start")

    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
        self.logs.append("chain_end")


class AsyncCustomCallbackHandler(AsyncCallbackHandler):
    def __init__(self):
        self.logs = []

    async def on_chain_start(
        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
    ) -> None:
        self.logs.append("chain_start")

    async def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
        self.logs.append("chain_end")


_CONFIG_PATTERNS = [
    # Configs with no user callbacks
    RunnableConfig(max_concurrency=1),
    RunnableConfig(callbacks=None),
    # Configs with user callbacks
    RunnableConfig(callbacks=[CustomCallbackHandler()]),
    RunnableConfig(callbacks=BaseCallbackManager([CustomCallbackHandler()])),
]

_ASYNC_CONFIG_PATTERNS = [
    RunnableConfig(callbacks=[AsyncCustomCallbackHandler()]),
    RunnableConfig(callbacks=BaseCallbackManager([AsyncCustomCallbackHandler()])),
]


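# Helpers for the callback-injection tests below: reset recorded events on
# user handlers, and pull handler lists back out of a config (or a list of
# configs), unwrapping BaseCallbackManager when present.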
def _reset_callback_handlers(handlers):
    if handlers:
        for handler in handlers:
            handler.logs = []


def _extract_callback_handlers(config) -> list[BaseCallbackHandler] | None:
    if isinstance(config, list):
        callbacks = []
        for c in config:
            if callbacks_in_c := _extract_callback_handlers(c):
                callbacks.extend(callbacks_in_c)
        return callbacks
    # RunnableConfig is also a dict
    elif isinstance(config, dict) and "callbacks" in config:
        callbacks = config["callbacks"]
        if isinstance(callbacks, BaseCallbackManager):
            return callbacks.handlers
        else:
            return callbacks
    else:
        return None


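# The following tests exercise the same scenario across invoke/ainvoke/
# batch/abatch/stream/astream: MLflow must record a trace regardless of how
# (or whether) a config is passed, must leave the user's callbacks unmutated,
# and the user's callbacks must still fire.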
@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize("config", _CONFIG_PATTERNS)
def test_langchain_autolog_callback_injection_in_invoke(invoke_arg, config, async_logging_enabled):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)

    input = {"product": "MLflow"}
    if invoke_arg == "args":
        model.invoke(input, config)
    elif invoke_arg == "kwargs":
        model.invoke(input, config=config)
    elif invoke_arg is None:
        model.invoke(input)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert traces[0].data.spans[0].name == "RunnableSequence"
    assert traces[0].data.spans[0].inputs == input
    assert traces[0].data.spans[0].outputs == [{"role": "user", "content": "What is MLflow?"}]
    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        # NB: LangChain has a bug where the callback is invoked a different
        # number of times depending on whether it is passed as a list or as a
        # callback manager. As a workaround, we only check the content of the
        # events, not the count.
        # https://github.com/langchain-ai/langchain/issues/24642
        assert set(handlers[0].logs) == {"chain_start", "chain_end"}


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize("config", _CONFIG_PATTERNS + _ASYNC_CONFIG_PATTERNS)
@pytest.mark.asyncio
async def test_langchain_autolog_callback_injection_in_ainvoke(
    invoke_arg, config, async_logging_enabled
):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)

    input = {"product": "MLflow"}
    if invoke_arg == "args":
        await model.ainvoke(input, config)
    elif invoke_arg == "kwargs":
        await model.ainvoke(input, config=config)
    elif invoke_arg is None:
        await model.ainvoke(input)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert traces[0].data.spans[0].name == "RunnableSequence"
    assert traces[0].data.spans[0].inputs == input
    assert traces[0].data.spans[0].outputs == [{"role": "user", "content": "What is MLflow?"}]

    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        # NB: LangChain has a bug where the callback is invoked a different
        # number of times depending on whether it is passed as a list or as a
        # callback manager. As a workaround, we only check the content of the
        # events, not the count.
        # https://github.com/langchain-ai/langchain/issues/24642
        assert set(handlers[0].logs) == {"chain_start", "chain_end"}


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize(
    "config",
    _CONFIG_PATTERNS
    # Lists of configs are also supported for batch calls
    + [[config, config] for config in _CONFIG_PATTERNS],
)
def test_langchain_autolog_callback_injection_in_batch(invoke_arg, config, async_logging_enabled):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)

    input = {"product": "MLflow"}
    if invoke_arg == "args":
        model.batch([input] * 2, config)
    elif invoke_arg == "kwargs":
        model.batch([input] * 2, config=config)
    elif invoke_arg is None:
        model.batch([input] * 2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 2
    for trace in traces:
        assert trace.info.status == "OK"
        assert trace.data.spans[0].name == "RunnableSequence"
        assert trace.data.spans[0].inputs == input
        assert trace.data.spans[0].outputs == [{"role": "user", "content": "What is MLflow?"}]

    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        for handler in handlers:
            assert set(handler.logs) == {"chain_start", "chain_end"}


@skip_when_testing_trace_sdk
def test_tracing_source_run_in_batch():
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    input = {"product": "MLflow"}
    with mlflow.start_run() as run:
        model.batch([input] * 2)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run.info.run_id


@skip_when_testing_trace_sdk
def test_tracing_source_run_in_pyfunc_model_predict(model_info):
    mlflow.langchain.autolog()

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    with mlflow.start_run() as run:
        pyfunc_model.predict([{"product": "MLflow"}] * 2)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run.info.run_id


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize(
    "config",
    _CONFIG_PATTERNS
    + _ASYNC_CONFIG_PATTERNS
    # Lists of configs are also supported for batch calls
    + [[config, config] for config in _CONFIG_PATTERNS + _ASYNC_CONFIG_PATTERNS],
)
@pytest.mark.asyncio
async def test_langchain_autolog_callback_injection_in_abatch(
    invoke_arg, config, async_logging_enabled
):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)

    input = {"product": "MLflow"}
    if invoke_arg == "args":
        await model.abatch([input] * 2, config)
    elif invoke_arg == "kwargs":
        await model.abatch([input] * 2, config=config)
    elif invoke_arg is None:
        await model.abatch([input] * 2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 2
    for trace in traces:
        assert trace.info.status == "OK"
        assert trace.data.spans[0].name == "RunnableSequence"
        assert trace.data.spans[0].inputs == input
        assert trace.data.spans[0].outputs == [{"role": "user", "content": "What is MLflow?"}]

    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        for handler in handlers:
            assert set(handler.logs) == {"chain_start", "chain_end"}


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize("config", _CONFIG_PATTERNS)
def test_langchain_autolog_callback_injection_in_stream(invoke_arg, config, async_logging_enabled):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)

    input = {"product": "MLflow"}
    if invoke_arg == "args":
        list(model.stream(input, config))
    elif invoke_arg == "kwargs":
        list(model.stream(input, config=config))
    elif invoke_arg is None:
        list(model.stream(input))

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert traces[0].data.spans[0].name == "RunnableSequence"
    assert traces[0].data.spans[0].inputs == input
    assert traces[0].data.spans[0].outputs == "Hello world"

    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        assert set(handlers[0].logs) == {"chain_start", "chain_end"}


@pytest.mark.parametrize("invoke_arg", ["args", "kwargs", None])
@pytest.mark.parametrize("config", _CONFIG_PATTERNS + _ASYNC_CONFIG_PATTERNS)
@pytest.mark.asyncio
async def test_langchain_autolog_callback_injection_in_astream(
    invoke_arg, config, async_logging_enabled
):
    mlflow.langchain.autolog()

    model = create_openai_runnable()
    original_handlers = _extract_callback_handlers(config)
    _reset_callback_handlers(original_handlers)
    input = {"product": "MLflow"}

    async def invoke_astream(model, config):
        if invoke_arg == "args":
            astream = model.astream(input, config)
        elif invoke_arg == "kwargs":
            astream = model.astream(input, config=config)
        elif invoke_arg is None:
            astream = model.astream(input)

        # Consume the stream
        async for _ in astream:
            pass

    await invoke_astream(model, config)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert traces[0].data.spans[0].name == "RunnableSequence"
    assert traces[0].data.spans[0].inputs == input
    assert traces[0].data.spans[0].outputs == "Hello world"

    # Original callback should not be mutated
    handlers = _extract_callback_handlers(config)
    assert handlers == original_handlers

    # The original callback is called by the chain
    if handlers and invoke_arg:
        assert set(handlers[0].logs) == {"chain_start", "chain_end"}


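# Streaming and non-streaming execution of the same chain should produce
# equivalent traces: same status, same request/response payloads, and the
# same number of spans.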
def test_langchain_autolog_produces_expected_traces_with_streaming(tmp_path, async_logging_enabled):
    mlflow.langchain.autolog()
    retriever, _ = create_retriever(tmp_path)
    prompt = ChatPromptTemplate.from_template(
        "Answer the following question based on the context: {context}\nQuestion: {question}"
    )
    chat_model = create_fake_chat_model()
    retrieval_chain = (
        {
            "context": retriever,
            "question": RunnablePassthrough(),
        }
        | prompt
        | chat_model
        | StrOutputParser()
    )
    question = "What is a good name for a company that makes MLflow?"
    list(retrieval_chain.stream(question))
    retrieval_chain.invoke(question)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 2
    stream_trace = traces[0]
    invoke_trace = traces[1]

    assert stream_trace.info.status == invoke_trace.info.status == TraceStatus.OK
    assert stream_trace.data.request == invoke_trace.data.request
    assert stream_trace.data.response == invoke_trace.data.response
    assert len(stream_trace.data.spans) == len(invoke_trace.data.spans)


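# Concurrent invocations from a thread pool must each yield a complete,
# isolated trace; spans from different threads must not bleed into each other.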
def test_langchain_autolog_tracing_thread_safe(async_logging_enabled):
    mlflow.langchain.autolog()

    model = create_openai_runnable()

    def _invoke():
        # Add a random sleep to simulate a real LLM prediction
        time.sleep(random.uniform(0.1, 0.5))

        model.invoke({"product": "MLflow"})

    with ThreadPoolExecutor(max_workers=8, thread_name_prefix="test-langchain-autolog") as executor:
        futures = [executor.submit(_invoke) for _ in range(30)]
        _ = [f.result() for f in futures]

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 30
    for trace in traces:
        assert trace.info.status == "OK"
        assert len(trace.data.spans) == 4
        assert trace.data.spans[0].name == "RunnableSequence"


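# Token usage should be aggregated onto trace.info, the model name recorded on
# the CHAT_MODEL span, and (outside the tracing-only SDK) per-span cost
# computed via the `mock_litellm_cost` fixture. Enabling OpenAI autologging on
# top of LangChain autologging must not double-count usage.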
@pytest.mark.asyncio
async def test_langchain_autolog_token_usage(mock_litellm_cost):
    mlflow.langchain.autolog()

    model = create_openai_runnable()

    def _validate_token_counts(trace):
        actual = trace.info.token_usage
        assert actual == {"input_tokens": 9, "output_tokens": 12, "total_tokens": 21}

    def _validate_model_name(trace):
        # Find the ChatOpenAI span
        chat_model_span = next(s for s in trace.data.spans if s.name == "ChatOpenAI")
        assert chat_model_span.model_name == "gpt-3.5-turbo"

    def _validate_cost(trace):
        if IS_TRACING_SDK_ONLY:
            return
        # Find the ChatOpenAI span
        chat_model_span = next(s for s in trace.data.spans if s.name == "ChatOpenAI")
        assert chat_model_span.llm_cost == {
            "input_cost": 9.0,
            "output_cost": 24.0,
            "total_cost": 33.0,
        }

    # Normal invoke
    model.invoke({"product": "MLflow"})
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    _validate_token_counts(trace)
    _validate_model_name(trace)
    _validate_cost(trace)

    # Invoke with streaming
    list(model.stream({"product": "MLflow"}))
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    _validate_token_counts(trace)
    _validate_model_name(trace)
    _validate_cost(trace)

    # Async invoke
    await model.ainvoke({"product": "MLflow"})
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    _validate_token_counts(trace)
    _validate_model_name(trace)
    _validate_cost(trace)

    # When both OpenAI and LangChain autologging are enabled,
    # no duplicate token usage should be logged
    mlflow.openai.autolog()

    model.invoke({"product": "MLflow"})
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    _validate_token_counts(trace)
    _validate_model_name(trace)
    _validate_cost(trace)


@pytest.mark.parametrize("log_traces", [True, False, None])
def test_langchain_tracer_injection_for_arbitrary_runnables(log_traces, async_logging_enabled):
    should_log_traces = log_traces is not False

    if log_traces is not None:
        mlflow.langchain.autolog(log_traces=log_traces)
    else:
        mlflow.langchain.autolog()

    add = RunnableLambda(func=lambda x: x + 1)
    square = RunnableLambda(func=lambda x: x**2)
    model = RouterRunnable(runnables={"add": add, "square": square})

    model.invoke({"key": "square", "input": 3})

    if async_logging_enabled and should_log_traces:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    if should_log_traces:
        assert len(traces) == 1
        assert traces[0].data.spans[0].span_type == "CHAIN"
    else:
        assert len(traces) == 0


@skip_when_testing_trace_sdk
@pytest.mark.skip(reason="This test is not thread safe, please run locally")
def test_set_retriever_schema_work_for_langchain_model(model_info):
    from mlflow.models.dependencies_schemas import DependenciesSchemasType, set_retriever_schema

    set_retriever_schema(
        primary_key="primary-key",
        text_column="text-column",
        doc_uri="doc-uri",
        other_columns=["column1", "column2"],
    )

    mlflow.langchain.autolog()

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    pyfunc_model.predict("MLflow")

    traces = get_traces()
    assert len(traces) == 1
    assert DependenciesSchemasType.RETRIEVERS.value in traces[0].info.tags

    purge_traces()

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    list(pyfunc_model.predict_stream("MLflow"))

    traces = get_traces()
    assert len(traces) == 1
    assert DependenciesSchemasType.RETRIEVERS.value in traces[0].info.tags


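# Simulate an environment where only langchain-core is installed (no langchain
# meta-package) by intercepting __import__; autologging should still trace.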
def test_langchain_auto_tracing_work_when_langchain_parent_package_not_installed(
    async_logging_enabled,
):
    original_import = __import__

    def _mock_import(name, *args):
        # Allow langchain.globals and its dependencies for langchain-core 0.3.76 compatibility
        allowed_langchain_modules = {
            "langchain.globals",
            "langchain._api",
            "langchain._api.interactive_env",
        }
        if name.startswith("langchain.") and name not in allowed_langchain_modules:
            raise ImportError("No module named 'langchain'")
        return original_import(name, *args)

    with mock.patch("builtins.__import__", side_effect=_mock_import):
        mlflow.langchain.autolog()

        chain, input_example = create_runnable_sequence()
        assert chain.invoke(input_example) == TEST_CONTENT
        assert chain.invoke(input_example) == TEST_CONTENT

        if async_logging_enabled:
            mlflow.flush_trace_async_logging(terminate=True)

        traces = get_traces()
        assert len(traces) == 2
        assert all(len(trace.data.spans) == 11 for trace in traces)


@skip_when_testing_trace_sdk
def test_langchain_auto_tracing_in_serving_runnable(model_info):
    mlflow.langchain.autolog()

    expected_output = '[{"role": "user", "content": "What is MLflow?"}]'
    databricks_request_id, predictions, trace = score_in_model_serving(
        model_info.model_uri,
        [{"product": "MLflow"}],
    )

    assert predictions == [expected_output]
    trace = Trace.from_dict(trace)
    assert trace.info.trace_id.startswith("tr-")
    assert trace.info.client_request_id == databricks_request_id
    assert trace.info.request_metadata[TRACE_SCHEMA_VERSION_KEY] == "3"
    spans = trace.data.spans
    assert len(spans) == 4

    root_span = spans[0]
    assert root_span.start_time_ns // 1_000_000 == trace.info.timestamp_ms
    # There may be a slight difference because nanoseconds are truncated to milliseconds
    assert (
        root_span.end_time_ns // 1_000_000
        - (trace.info.timestamp_ms + trace.info.execution_time_ms)
    ) <= 1
    assert root_span.inputs == {"product": "MLflow"}
    assert root_span.outputs == expected_output
    assert root_span.span_type == "CHAIN"

    root_span_id = root_span.span_id
    child_span = spans[2]
    assert child_span.parent_id == root_span_id
    assert child_span.inputs["messages"][0]["content"] == "What is MLflow?"
    assert child_span.outputs["choices"][0]["message"]["content"] == expected_output
    assert child_span.span_type == "CHAT_MODEL"


@pytest.mark.skipif(not IS_LANGCHAIN_v1, reason="create_agent is not supported in langchain v0")
@skip_when_testing_trace_sdk
def test_langchain_auto_tracing_in_serving_agent():
    mlflow.langchain.autolog()

    input_example = {"input": "What is 2 * 3?"}

    with mlflow.start_run():
        model_info = mlflow.langchain.log_model(
            "tests/langchain/sample_code/openai_agent.py",
            name="langchain_model",
            input_example=input_example,
        )

    databricks_request_id, response, trace_dict = score_in_model_serving(
        model_info.model_uri,
        input_example,
    )

    trace = Trace.from_dict(trace_dict)
    assert trace.info.trace_id.startswith("tr-")
    assert trace.info.client_request_id == databricks_request_id
    assert trace.info.status == "OK"

    spans = trace.data.spans
    assert len(spans) == 7

    root_span = spans[0]
    assert root_span.name == "LangGraph"
    assert root_span.span_type == SpanType.CHAIN
    assert root_span.inputs["input"] == "What is 2 * 3?"
    assert root_span.outputs["messages"][-1]["content"] == "The result of 2 * 3 is 6."
    assert root_span.start_time_ns // 1_000_000 == trace.info.timestamp_ms
    assert (
        root_span.end_time_ns // 1_000_000
        - (trace.info.timestamp_ms + trace.info.execution_time_ms)
    ) <= 1


def test_langchain_tracing_multi_threads():
    mlflow.langchain.autolog()

    temperatures = [(t + 1) / 10 for t in range(4)]
    models = [create_openai_runnable(temperature=t) for t in temperatures]

    with ThreadPoolExecutor(
        max_workers=len(temperatures), thread_name_prefix="test-langchain-concurrent"
    ) as executor:
        futures = [executor.submit(models[i].invoke, {"product": "MLflow"}) for i in range(4)]
        for f in futures:
            f.result()

    traces = get_traces()
    assert len(traces) == 4
    assert (
        sorted(
            trace.data.spans[2].get_attribute("invocation_params")["temperature"]
            for trace in traces
        )
        == temperatures
    )


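# The link-traces tests encode each model's temperature and model_id into the
# request payload so every trace can be matched back to the exact logged model
# that produced it via TraceMetadataKey.MODEL_ID.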
@skip_when_testing_trace_sdk
@pytest.mark.parametrize("func", ["invoke", "batch", "stream"])
def test_autolog_link_traces_to_loaded_model(model_infos, func):
    mlflow.langchain.autolog()

    for model_info in model_infos:
        loaded_model = mlflow.langchain.load_model(model_info.model_uri)
        msg = {"product": f"{loaded_model.steps[1].temperature}_{model_info.model_id}"}
        if func == "invoke":
            loaded_model.invoke(msg)
        elif func == "batch":
            loaded_model.batch([msg])
        elif func == "stream":
            list(loaded_model.stream(msg))

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        temp = trace.data.spans[2].get_attribute("invocation_params")["temperature"]
        logged_temp, logged_model_id = json.loads(trace.data.request)["product"].split(
            "_", maxsplit=1
        )
        assert logged_model_id is not None
        assert str(temp) == logged_temp
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == logged_model_id


@skip_when_testing_trace_sdk
@pytest.mark.parametrize("func", ["ainvoke", "abatch", "astream"])
@pytest.mark.asyncio
async def test_autolog_link_traces_to_loaded_model_async(model_infos, func):
    mlflow.langchain.autolog()

    for model_info in model_infos:
        loaded_model = mlflow.langchain.load_model(model_info.model_uri)
        msg = {"product": f"{loaded_model.steps[1].temperature}_{model_info.model_id}"}
        if func == "ainvoke":
            await loaded_model.ainvoke(msg)
        elif func == "abatch":
            await loaded_model.abatch([msg])
        elif func == "astream":
            async for _ in loaded_model.astream(msg):
                pass

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        temp = trace.data.spans[2].get_attribute("invocation_params")["temperature"]
        logged_temp, logged_model_id = json.loads(trace.data.request)["product"].split(
            "_", maxsplit=1
        )
        assert logged_model_id is not None
        assert str(temp) == logged_temp
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == logged_model_id


@skip_when_testing_trace_sdk
def test_autolog_link_traces_to_loaded_model_pyfunc(model_infos):
    mlflow.langchain.autolog()

    for model_info in model_infos:
        loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
        loaded_model.predict({"product": model_info.model_id})

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        logged_model_id = json.loads(trace.data.request)["product"]
        assert logged_model_id is not None
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == logged_model_id


@skip_when_testing_trace_sdk
def test_autolog_link_traces_to_active_model(model_infos):
    model = mlflow.create_external_model(name="test_model")
    mlflow.set_active_model(model_id=model.model_id)
    mlflow.langchain.autolog()

    for model_info in model_infos:
        loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
        loaded_model.predict({"product": model_info.model_id})

    traces = get_traces()
    assert len(traces) == len(model_infos)
    for trace in traces:
        logged_model_id = json.loads(trace.data.request)["product"]
        assert logged_model_id is not None
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == model.model_id
        assert model.model_id != logged_model_id


@skip_when_testing_trace_sdk
def test_model_loading_set_active_model_id_without_fetching_logged_model(model_info):
    mlflow.langchain.autolog()

    with mock.patch("mlflow.get_logged_model", side_effect=Exception("get_logged_model failed")):
        loaded_model = mlflow.langchain.load_model(model_info.model_uri)
    loaded_model.invoke({"product": "MLflow"})

    traces = get_traces()
    assert len(traces) == 1
    model_id = traces[0].info.request_metadata[TraceMetadataKey.MODEL_ID]
    assert model_id == model_info.model_id


@skip_when_testing_trace_sdk
@pytest.mark.parametrize("log_traces", [True, False])
def test_langchain_tracing_evaluate(log_traces):
    from mlflow.genai import scorer

    if log_traces:
        mlflow.langchain.autolog()
        mlflow.openai.autolog()  # Our chain contains an OpenAI call as well

    chain = create_openai_runnable()

    data = [
        {
            "inputs": {"product": "MLflow"},
            "expectations": {"expected_response": "MLflow is an open-source platform."},
        },
        {
            "inputs": {"product": "Spark"},
            "expectations": {"expected_response": "Spark is a unified analytics engine."},
        },
    ]

    def predict_fn(product: str) -> str:
        return chain.invoke({"product": product})

    @scorer
    def exact_match(outputs: str, expectations: dict[str, str]) -> bool:
        return outputs == expectations["expected_response"]

    result = mlflow.genai.evaluate(
        predict_fn=predict_fn,
        data=data,
        scorers=[exact_match],
    )
    assert result.metrics["exact_match/mean"] == 0.0
    assert result.result_df is not None

    # Tracing should be enabled automatically during evaluation
    assert len(get_traces()) == 2
    for trace in get_traces():
        assert len(trace.data.spans) == 5
        assert trace.data.spans[0].name == "RunnableSequence"
        assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == result.run_id
        assert len(trace.info.assessments) == 2


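# With run_tracer_inline=True, LangChain spans must nest correctly under
# manually created @mlflow.trace spans, producing a single combined trace.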
@pytest.mark.asyncio
async def test_autolog_run_tracer_inline_with_manual_traces_async():
    mlflow.langchain.autolog(run_tracer_inline=True)

    prompt = PromptTemplate(
        input_variables=["color"],
        template="What is the complementary color of {color}?",
    )
    llm = ChatOpenAI()

    @mlflow.trace
    def manual_transform(s: str):
        return s.replace("red", "blue")

    chain = RunnableLambda(manual_transform) | prompt | llm | StrOutputParser()

    @mlflow.trace(name="parent")
    async def run(message):
        return await chain.ainvoke(message)

    response = await run("red")
    expected_response = '[{"role": "user", "content": "What is the complementary color of blue?"}]'
    assert response == expected_response

    traces = get_traces()
    assert len(traces) == 1

    trace = traces[0]
    spans = trace.data.spans
    assert spans[0].name == "parent"
    assert spans[1].name == "RunnableSequence"
    assert spans[1].parent_id == spans[0].span_id
    assert spans[2].name == "manual_transform"
    assert spans[2].parent_id == spans[1].span_id
    # Find the ChatOpenAI span and verify it has the model name recorded
    chat_model_span = next(s for s in spans if s.name == "ChatOpenAI")
    assert chat_model_span.model_name == "gpt-3.5-turbo"