# tests/langchain/test_langchain_tracer.py
  1  import random
  2  import time
  3  import uuid
  4  from concurrent.futures import ThreadPoolExecutor
  5  from typing import Any
  6  from unittest.mock import MagicMock
  7  
  8  import pydantic
  9  import pytest
 10  from langchain_community.document_loaders import TextLoader
 11  from langchain_community.embeddings import FakeEmbeddings
 12  from langchain_community.vectorstores import FAISS
 13  from langchain_core.documents import Document
 14  from langchain_core.language_models.chat_models import SimpleChatModel
 15  from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
 16  from langchain_core.output_parsers.string import StrOutputParser
 17  from langchain_core.outputs import LLMResult
 18  from langchain_core.prompts import PromptTemplate
 19  from langchain_core.prompts.chat import SystemMessagePromptTemplate
 20  from langchain_core.runnables import RunnableLambda
 21  from langchain_core.tools import tool
 22  from langchain_openai import ChatOpenAI
 23  from langchain_text_splitters.character import CharacterTextSplitter
 24  
 25  import mlflow
 26  from mlflow.entities import Document as MlflowDocument
 27  from mlflow.entities import Trace
 28  from mlflow.entities.span_event import SpanEvent
 29  from mlflow.entities.span_status import SpanStatus, SpanStatusCode
 30  from mlflow.exceptions import MlflowException
 31  from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
 32  from mlflow.langchain.model import _LangChainModelWrapper
 33  from mlflow.tracing.constant import SpanAttributeKey
 34  from mlflow.tracing.provider import trace_disabled
 35  
 36  from tests.tracing.helper import get_traces
 37  
# The mock OpenAI endpoint simply echos the prompt back as the completion.
# So the expected output will be the prompt itself.
# Shared by the fake chat models below as both the prompt and the expected reply.
TEST_CONTENT = "What is MLflow?"
 41  
 42  
 43  def create_openai_runnable(temperature=0.9):
 44      prompt = PromptTemplate(
 45          input_variables=["product"],
 46          template="What is {product}?",
 47      )
 48      llm = ChatOpenAI(temperature=temperature, stream_usage=True)
 49      return prompt | llm | StrOutputParser()
 50  
 51  
 52  def create_retriever():
 53      loader = TextLoader("tests/scoring/state_of_the_union.txt")
 54      documents = loader.load()
 55      text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
 56      docs = text_splitter.split_documents(documents)
 57      embeddings = FakeEmbeddings(size=5)
 58      db = FAISS.from_documents(docs, embeddings)
 59      return db.as_retriever()
 60  
 61  
 62  def _validate_trace_json_serialization(trace):
 63      trace_dict = trace.to_dict()
 64      trace_from_dict = Trace.from_dict(trace_dict)
 65      trace_json = trace.to_json()
 66      trace_from_json = Trace.from_json(trace_json)
 67      for loaded_trace in [trace_from_dict, trace_from_json]:
 68          assert trace.info == loaded_trace.info
 69          assert trace.data.request == loaded_trace.data.request
 70          assert trace.data.response == loaded_trace.data.response
 71          assert len(trace.data.spans) == len(loaded_trace.data.spans)
 72          for i in range(len(trace.data.spans)):
 73              for attr in [
 74                  "name",
 75                  "request_id",
 76                  "span_id",
 77                  "start_time_ns",
 78                  "end_time_ns",
 79                  "parent_id",
 80                  "status",
 81                  "inputs",
 82                  "outputs",
 83                  "_trace_id",
 84                  "attributes",
 85                  "events",
 86              ]:
 87                  assert getattr(trace.data.spans[i], attr) == getattr(
 88                      loaded_trace.data.spans[i], attr
 89                  )
 90  
 91  
 92  def test_llm_success():
 93      callback = MlflowLangchainTracer()
 94      run_id = str(uuid.uuid4())
 95      callback.on_llm_start(
 96          {},
 97          ["test prompt"],
 98          run_id=run_id,
 99          name="test_llm",
100      )
101  
102      callback.on_llm_new_token("test", run_id=run_id)
103  
104      callback.on_llm_end(LLMResult(generations=[[{"text": "generated text"}]]), run_id=run_id)
105      trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
106      assert len(trace.data.spans) == 1
107      llm_span = trace.data.spans[0]
108  
109      assert llm_span.name == "test_llm"
110  
111      assert llm_span.span_type == "LLM"
112      assert llm_span.start_time_ns is not None
113      assert llm_span.end_time_ns is not None
114      assert llm_span.status == SpanStatus(SpanStatusCode.OK)
115      assert llm_span.inputs == ["test prompt"]
116      assert llm_span.outputs["choices"][0]["message"]["content"] == "generated text"
117      assert llm_span.events[0].name == "new_token"
118  
119      _validate_trace_json_serialization(trace)
120  
121  
def test_llm_error():
    """An LLM error marks the span ERROR and attaches an exception event."""
    tracer = MlflowLangchainTracer()
    llm_run_id = str(uuid.uuid4())
    tracer.on_llm_start({}, ["test prompt"], run_id=llm_run_id, name="test_llm")
    failure = Exception("mock exception")
    tracer.on_llm_error(error=failure, run_id=llm_run_id)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.status.status_code == SpanStatusCode.ERROR
    assert span.status.description == str(failure)
    assert span.inputs == ["test prompt"]
    assert span.outputs is None
    # timestamp is auto-generated when converting the error to event, so only
    # the event name and attributes are compared.
    expected_event = SpanEvent.from_exception(failure)
    assert span.events[0].name == expected_event.name
    assert span.events[0].attributes == expected_event.attributes

    _validate_trace_json_serialization(trace)
147  
148  
def test_llm_internal_exception():
    """Ending an LLM run with an unknown run_id raises; flush must still succeed."""
    tracer = MlflowLangchainTracer()
    tracer.on_llm_start({}, ["test prompt"], run_id=str(uuid.uuid4()), name="test_llm")
    try:
        with pytest.raises(Exception, match="Span for run_id dummy not found."):
            tracer.on_llm_end(LLMResult(generations=[[{"text": "generated"}]]), run_id="dummy")
    finally:
        # Flush even on failure so the dangling root span does not leak into
        # other tests.
        tracer.flush()
166  
167  
def test_chat_model():
    """Chat-model start/end callbacks record messages in role/content chat format."""
    tracer = MlflowLangchainTracer()
    chat_run_id = str(uuid.uuid4())
    conversation = [SystemMessage("system prompt"), HumanMessage("test prompt")]
    tracer.on_chat_model_start({}, [conversation], run_id=chat_run_id, name="test_chat_model")
    tracer.on_llm_end(
        LLMResult(generations=[[{"text": "generated text"}]]),
        run_id=chat_run_id,
    )

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.name == "test_chat_model"
    assert span.span_type == "CHAT_MODEL"
    assert span.status.status_code == SpanStatusCode.OK
    recorded = span.inputs["messages"]
    assert recorded[0]["role"] == "system"
    assert recorded[0]["content"] == "system prompt"
    assert recorded[1]["role"] == "user"
    assert recorded[1]["content"] == "test prompt"
    assert span.outputs["choices"][0]["message"]["content"] == "generated text"
194  
195  
def test_chat_model_with_tool():
    """An OpenAI-format tool definition is attached to the chat span unchanged."""
    tracer = MlflowLangchainTracer()
    chat_run_id = str(uuid.uuid4())
    # OpenAI tool format
    openai_tool = {
        "type": "function",
        "function": {
            "name": "GetWeather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "properties": {
                    "location": {
                        "description": "The city and state, e.g. San Francisco, CA",
                        "type": "string",
                    }
                },
                "required": ["location"],
                "type": "object",
            },
        },
    }
    tracer.on_chat_model_start(
        {},
        [[HumanMessage("test prompt")]],
        run_id=chat_run_id,
        name="test_chat_model",
        invocation_params={"tools": [openai_tool]},
    )
    tracer.on_llm_end(
        LLMResult(generations=[[{"text": "generated text"}]]),
        run_id=chat_run_id,
    )

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.status.status_code == SpanStatusCode.OK
    # Already OpenAI-shaped, so the attribute holds the definition verbatim.
    assert span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == [openai_tool]
235  
236  
def test_chat_model_with_non_openai_tool():
    """A non-OpenAI tool definition is normalized to the OpenAI function shape."""
    tracer = MlflowLangchainTracer()
    chat_run_id = str(uuid.uuid4())
    # Anthropic tool format
    anthropic_tool = {
        "name": "get_weather",
        "description": "Get the weather for a location.",
        "input_schema": {
            "properties": {
                "location": {
                    "description": "The city and state, e.g. San Francisco, CA",
                    "type": "string",
                }
            },
            "required": ["location"],
            "type": "object",
        },
    }
    tracer.on_chat_model_start(
        {},
        [[HumanMessage("test prompt")]],
        run_id=chat_run_id,
        name="test_chat_model",
        invocation_params={"tools": [anthropic_tool]},
    )
    tracer.on_llm_end(
        LLMResult(generations=[[{"text": "generated text"}]]),
        run_id=chat_run_id,
    )

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.status.status_code == SpanStatusCode.OK
    # Only name and description are carried over; the input schema is dropped.
    assert span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the weather for a location.",
            },
        }
    ]
281  
282  
def test_retriever_success():
    """Retriever start/end callbacks produce one OK RETRIEVER span."""
    tracer = MlflowLangchainTracer()
    retriever_run_id = str(uuid.uuid4())
    tracer.on_retriever_start(
        {},
        query="test query",
        run_id=retriever_run_id,
        name="test_retriever",
    )

    retrieved_docs = [
        Document(
            page_content="document content 1",
            metadata={"chunk_id": "1", "doc_uri": "uri1"},
        ),
        Document(
            page_content="document content 2",
            metadata={"chunk_id": "2", "doc_uri": "uri2"},
        ),
    ]
    tracer.on_retriever_end(retrieved_docs, run_id=retriever_run_id)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]

    assert span.name == "test_retriever"
    assert span.span_type == "RETRIEVER"
    assert span.inputs == "test query"
    # Outputs are serialized as MLflow Document dicts.
    expected = [MlflowDocument.from_langchain_document(d).to_dict() for d in retrieved_docs]
    assert span.outputs == expected
    assert span.start_time_ns is not None
    assert span.end_time_ns is not None
    assert span.status.status_code == SpanStatusCode.OK

    _validate_trace_json_serialization(trace)
319  
320  
def test_retriever_error():
    """A retriever error marks the span ERROR and records an exception event."""
    tracer = MlflowLangchainTracer()
    retriever_run_id = str(uuid.uuid4())
    tracer.on_retriever_start(
        {},
        query="test query",
        run_id=retriever_run_id,
        name="test_retriever",
    )
    failure = Exception("mock exception")
    tracer.on_retriever_error(error=failure, run_id=retriever_run_id)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    span = trace.data.spans[0]
    assert span.inputs == "test query"
    assert span.outputs is None
    assert span.status.status_code == SpanStatusCode.ERROR
    # Event timestamps are auto-generated; compare only name and attributes.
    expected_event = SpanEvent.from_exception(failure)
    assert span.events[0].name == expected_event.name
    assert span.events[0].attributes == expected_event.attributes

    _validate_trace_json_serialization(trace)
343  
344  
def test_retriever_internal_exception():
    """Ending a retriever run with an unknown run_id raises; flush still works."""
    tracer = MlflowLangchainTracer()
    tracer.on_retriever_start(
        {},
        query="test query",
        run_id=str(uuid.uuid4()),
        name="test_retriever",
    )
    stray_docs = [
        Document(
            page_content="document content 1",
            metadata={"chunk_id": "1", "doc_uri": "uri1"},
        )
    ]
    try:
        with pytest.raises(Exception, match="Span for run_id dummy not found."):
            tracer.on_retriever_end(stray_docs, run_id="dummy")
    finally:
        # Flush even on failure so the dangling root span does not leak into
        # other tests.
        tracer.flush()
371  
372  
def test_multiple_components():
    """Nested LLM/retriever runs under one chain form a single 5-span trace."""
    tracer = MlflowLangchainTracer()
    chain_run_id = str(uuid.uuid4())
    tracer.on_chain_start(
        {},
        inputs={"input": "test input"},
        run_id=chain_run_id,
        name="test_chain",
    )
    for idx in range(2):
        llm_run_id = str(uuid.uuid4())
        retriever_run_id = str(uuid.uuid4())
        tracer.on_llm_start(
            {},
            [f"test prompt {idx}"],
            run_id=llm_run_id,
            name="test_llm",
            parent_run_id=chain_run_id,
        )
        tracer.on_retriever_start(
            {},
            query=f"test query {idx}",
            run_id=retriever_run_id,
            name="test_retriever",
            parent_run_id=llm_run_id,
        )
        retrieved = [
            Document(
                page_content=f"document content {idx}",
                metadata={
                    "chunk_id": str(idx),
                    "doc_uri": f"https://mock_uri.com/{idx}",
                },
            )
        ]
        tracer.on_retriever_end(retrieved, run_id=retriever_run_id)
        tracer.on_llm_end(
            LLMResult(generations=[[{"text": f"generated text {idx}"}]]),
            run_id=llm_run_id,
        )
    tracer.on_chain_end(outputs={"output": "test output"}, run_id=chain_run_id)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 5

    chain_span = trace.data.spans[0]
    assert chain_span.start_time_ns is not None
    assert chain_span.end_time_ns is not None
    assert chain_span.name == "test_chain"
    assert chain_span.span_type == "CHAIN"
    assert chain_span.parent_id is None
    assert chain_span.status.status_code == SpanStatusCode.OK
    assert chain_span.inputs == {"input": "test input"}
    assert chain_span.outputs == {"output": "test output"}

    # Span order is [chain, llm0, retriever0, llm1, retriever1].
    for idx in range(2):
        llm_span = trace.data.spans[1 + idx * 2]
        assert llm_span.inputs == [f"test prompt {idx}"]
        assert llm_span.outputs["choices"][0]["message"]["content"] == f"generated text {idx}"
        retriever_span = trace.data.spans[2 + idx * 2]
        assert retriever_span.inputs == f"test query {idx}"
        expected_doc = MlflowDocument(
            page_content=f"document content {idx}",
            metadata={
                "chunk_id": str(idx),
                "doc_uri": f"https://mock_uri.com/{idx}",
            },
        ).to_dict()
        assert retriever_span.outputs[0] == expected_doc

    _validate_trace_json_serialization(trace)
448  
449  
def test_tool_success():
    """Invoking a chain wrapped as a tool traces the tool plus its chain parts."""
    tracer = MlflowLangchainTracer()
    prompt = SystemMessagePromptTemplate.from_template("You are a nice assistant.") + "{question}"
    chain = prompt | ChatOpenAI() | StrOutputParser()
    chain_tool = tool("chain_tool", chain)

    tool_input = {"question": "What up"}
    chain_tool.invoke(tool_input, config={"callbacks": [tracer]})

    # str output is converted to _ChatResponse
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    spans = trace.data.spans
    assert len(spans) == 5

    tool_span, sequence_span, prompt_span, llm_span, parser_span = spans

    # Tool (root span)
    assert tool_span.span_type == "TOOL"
    assert tool_span.inputs == tool_input
    assert tool_span.outputs is not None

    # RunnableSequence nested under the tool span
    assert sequence_span.parent_id == tool_span.span_id
    assert sequence_span.span_type == "CHAIN"
    assert sequence_span.inputs == tool_input
    assert sequence_span.outputs is not None

    # PromptTemplate
    assert prompt_span.span_type == "CHAIN"
    # LLM
    assert llm_span.span_type == "CHAT_MODEL"
    # StrOutputParser: the mock endpoint echoes the prompt, so the parsed
    # output is the serialized message list.
    assert parser_span.span_type == "CHAIN"
    assert parser_span.outputs == [
        {"content": "You are a nice assistant.", "role": "system"},
        {"content": "What up", "role": "user"},
    ]

    _validate_trace_json_serialization(trace)
495  
496  
def test_tracer_thread_safe():
    """Concurrent chains sharing one tracer must each yield an isolated trace."""
    tracer = MlflowLangchainTracer()

    def run_single_chain(worker_id):
        run_id = str(uuid.uuid4())
        tracer.on_chain_start(
            {}, {"input": "test input"}, run_id=run_id, name=f"chain_{worker_id}"
        )
        # Sleep 0.5~1s so the chains genuinely overlap, mimicking real traffic.
        time.sleep(random.random() / 2 + 0.5)
        tracer.on_chain_end({"output": "test output"}, run_id=run_id)

    with ThreadPoolExecutor(max_workers=10, thread_name_prefix="test-langchain-tracer") as executor:
        pending = [executor.submit(run_single_chain, i) for i in range(10)]
        for future in pending:
            future.result()

    traces = get_traces()
    assert len(traces) == 10
    assert all(len(trace.data.spans) == 1 for trace in traces)
517  
518  
def test_tracer_does_not_add_spans_to_trace_after_root_run_has_finished():
    """After the root chain run ends, the tracer must not retain span handles."""

    class FakeChatModel(SimpleChatModel):
        """Fake Chat Model wrapper for testing purposes."""

        def _call(self, messages: list[BaseMessage], **kwargs: Any) -> str:
            return TEST_CONTENT

        @property
        def _llm_type(self) -> str:
            return "fake chat model"

    # Captures the root run_id so on_chain_end can be replayed after the fact.
    run_id_for_on_chain_end = None

    class ExceptionCatchingTracer(MlflowLangchainTracer):
        # Record the run_id of every chain end, then delegate to the real tracer.
        def on_chain_end(self, outputs, *, run_id, inputs=None, **kwargs):
            nonlocal run_id_for_on_chain_end
            run_id_for_on_chain_end = run_id
            super().on_chain_end(outputs, run_id=run_id, inputs=inputs, **kwargs)

    prompt = SystemMessagePromptTemplate.from_template("You are a nice assistant.") + "{question}"
    chain = prompt | FakeChatModel() | StrOutputParser()

    tracer = ExceptionCatchingTracer()

    chain.invoke(
        "What is MLflow?",
        config={"callbacks": [tracer]},
    )

    with pytest.raises(MlflowException, match="Span for run_id .* not found."):
        # After the chain is invoked, verify that the tracer no longer holds references to spans,
        # ensuring that the tracer does not add spans to the trace after the root run has finished
        tracer.on_chain_end({"output": "test output"}, run_id=run_id_for_on_chain_end, inputs=None)
552  
553  
def test_tracer_noop_when_tracing_disabled(monkeypatch):
    """With tracing disabled, the tracer is a silent no-op: no traces, no warnings."""
    model = _LangChainModelWrapper(create_openai_runnable())

    @trace_disabled
    def _predict():
        return model._predict_with_callbacks(
            ["MLflow"],
            callback_handlers=[MlflowLangchainTracer()],
            convert_chat_responses=True,
        )

    mock_logger = MagicMock()
    monkeypatch.setattr(mlflow.tracking.client, "_logger", mock_logger)

    assert _predict() is not None
    assert get_traces() == []
    # No warning should be issued
    mock_logger.warning.assert_not_called()
574  
575  
def test_tracer_with_manual_traces():
    # Validate if the callback works properly when outer and inner spans
    # are created by fluent APIs.
    llm = ChatOpenAI()
    prompt = PromptTemplate(
        input_variables=["color"],
        template="What is the complementary color of {color}?",
    )

    # Inner spans are created within RunnableLambda
    def foo(s: str):
        with mlflow.start_span(name="foo_inner") as span:
            span.set_inputs(s)
            s = s.replace("red", "blue")
            s = bar(s)
            span.set_outputs(s)
        return s

    @mlflow.trace
    def bar(s):
        return s.replace("blue", "green")

    chain = RunnableLambda(foo) | prompt | llm | StrOutputParser()

    @mlflow.trace(name="parent", span_type="SPECIAL")
    def run(message):
        return chain.invoke(message, config={"callbacks": [MlflowLangchainTracer()]})

    response = run("red")
    # "red" -> "blue" (foo) -> "green" (bar); the mock OpenAI endpoint echoes
    # the prompt back, so "green" appears in the echoed message.
    expected_response = '[{"role": "user", "content": "What is the complementary color of green?"}]'
    assert response == expected_response

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    spans = trace.data.spans
    # Expected hierarchy: parent -> RunnableSequence -> foo -> foo_inner -> bar,
    # with PromptTemplate as a direct child of the RunnableSequence span.
    assert spans[0].name == "parent"
    assert spans[1].name == "RunnableSequence"
    assert spans[1].parent_id == spans[0].span_id
    assert spans[2].name == "foo"
    assert spans[2].parent_id == spans[1].span_id
    assert spans[3].name == "foo_inner"
    assert spans[3].parent_id == spans[2].span_id
    assert spans[4].name == "bar"
    assert spans[4].parent_id == spans[3].span_id
    assert spans[5].name == "PromptTemplate"
    assert spans[5].parent_id == spans[1].span_id
622  
623  
def test_serialize_invocation_params_success():
    """A pydantic response_format is replaced by its JSON schema; others kept."""

    class DummyModel(pydantic.BaseModel):
        field: str

    tracer = MlflowLangchainTracer()
    result = tracer._serialize_invocation_params(
        {"invocation_params": {"response_format": DummyModel, "other_param": "preserved"}}
    )
    assert "invocation_params" in result
    assert "response_format" in result["invocation_params"]
    assert result["invocation_params"]["response_format"] == DummyModel.model_json_schema()
    assert result["invocation_params"]["other_param"] == "preserved"
636  
637  
def test_serialize_invocation_params_failure():
    """If schema generation raises, the original response_format is preserved."""

    class FaultyModel(pydantic.BaseModel):
        field: str

        @classmethod
        def model_json_schema(cls):
            raise Exception("dummy failure")

    tracer = MlflowLangchainTracer()
    result = tracer._serialize_invocation_params(
        {"invocation_params": {"response_format": FaultyModel, "other_param": "preserved"}}
    )
    # Serialization failure must not drop or mangle the original values.
    assert result["invocation_params"]["response_format"] == FaultyModel
    assert result["invocation_params"]["other_param"] == "preserved"
651  
652  
def test_serialize_invocation_params_non_pydantic_response_format():
    """Non-pydantic response_format values pass through untouched."""
    tracer = MlflowLangchainTracer()
    sample_values = ["string_value", {"dict_key": "value"}, 123, ["list", "of", "items"], None]

    for value in sample_values:
        result = tracer._serialize_invocation_params(
            {"invocation_params": {"response_format": value, "other_param": "preserved"}}
        )
        assert result["invocation_params"]["response_format"] == value
        assert result["invocation_params"]["other_param"] == "preserved"
664  
665  
def test_serialize_invocation_params_no_invocation_params():
    """Attribute dicts without an invocation_params key are returned unchanged."""
    attributes = {"other_key": "value"}
    assert MlflowLangchainTracer()._serialize_invocation_params(attributes) == attributes
671  
672  
def test_serialize_invocation_params_none():
    """A None attributes value is passed through as None."""
    assert MlflowLangchainTracer()._serialize_invocation_params(None) is None
677  
678  
@pytest.mark.asyncio
async def test_tracer_with_manual_traces_async():
    """Async chain invocation should nest correctly under a fluent-API parent span."""
    llm = ChatOpenAI()
    prompt = PromptTemplate(
        input_variables=["color"],
        template="What is the complementary color of {color}?",
    )

    @mlflow.trace
    def manual_transform(s: str):
        return s.replace("red", "blue")

    chain = RunnableLambda(manual_transform) | prompt | llm | StrOutputParser()

    @mlflow.trace(name="parent")
    async def run(message):
        # run_inline=True ensures proper context propagation in async scenarios
        tracer = MlflowLangchainTracer(run_inline=True)
        return await chain.ainvoke(message, config={"callbacks": [tracer]})

    response = await run("red")
    # The mock OpenAI endpoint echoes the prompt, hence "blue" in the response.
    expected_response = '[{"role": "user", "content": "What is the complementary color of blue?"}]'
    assert response == expected_response

    traces = get_traces()
    assert len(traces) == 1

    trace = traces[0]
    spans = trace.data.spans
    # Expected hierarchy: parent -> RunnableSequence -> manual_transform.
    assert spans[0].name == "parent"
    assert spans[1].name == "RunnableSequence"
    assert spans[1].parent_id == spans[0].span_id
    assert spans[2].name == "manual_transform"
    assert spans[2].parent_id == spans[1].span_id
713  
714  
@pytest.mark.parametrize(
    ("_type", "expected_provider"),
    [
        ("openai-chat", "openai"),
        ("anthropic-chat", "anthropic"),
        ("bedrock-chat", "bedrock"),
        ("openai", "openai"),
    ],
)
def test_chat_model_extracts_model_provider(_type, expected_provider):
    """The provider attribute is derived from invocation_params['_type']."""
    tracer = MlflowLangchainTracer()
    chat_run_id = str(uuid.uuid4())
    tracer.on_chat_model_start(
        {},
        [[HumanMessage("test")]],
        run_id=chat_run_id,
        name="test_chat_model",
        invocation_params={"model": "gpt-4", "_type": _type},
    )
    tracer.on_llm_end(
        LLMResult(generations=[[{"text": "response"}]]),
        run_id=chat_run_id,
    )

    span = mlflow.get_trace(mlflow.get_last_active_trace_id()).data.spans[0]
    assert span.get_attribute(SpanAttributeKey.MODEL) == "gpt-4"
    assert span.get_attribute(SpanAttributeKey.MODEL_PROVIDER) == expected_provider
743  
744  
def test_chat_model_no_provider_when_type_missing():
    """Without a '_type' entry the provider attribute stays unset."""
    tracer = MlflowLangchainTracer()
    chat_run_id = str(uuid.uuid4())
    tracer.on_chat_model_start(
        {},
        [[HumanMessage("test")]],
        run_id=chat_run_id,
        name="test_chat_model",
        invocation_params={"model": "gpt-4"},
    )
    tracer.on_llm_end(
        LLMResult(generations=[[{"text": "response"}]]),
        run_id=chat_run_id,
    )

    span = mlflow.get_trace(mlflow.get_last_active_trace_id()).data.spans[0]
    assert span.get_attribute(SpanAttributeKey.MODEL) == "gpt-4"
    assert span.get_attribute(SpanAttributeKey.MODEL_PROVIDER) is None
764  
765  
@pytest.mark.parametrize("run_tracer_inline", [True, False])
def test_tracer_run_inline_parameter(run_tracer_inline):
    """The run_inline constructor flag is stored on the tracer as-is."""
    assert MlflowLangchainTracer(run_inline=run_tracer_inline).run_inline == run_tracer_inline