# test_langchain_tracer.py
import random
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from unittest.mock import MagicMock

import pydantic
import pytest
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.outputs import LLMResult
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import SystemMessagePromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langchain_text_splitters.character import CharacterTextSplitter

import mlflow
from mlflow.entities import Document as MlflowDocument
from mlflow.entities import Trace
from mlflow.entities.span_event import SpanEvent
from mlflow.entities.span_status import SpanStatus, SpanStatusCode
from mlflow.exceptions import MlflowException
from mlflow.langchain.langchain_tracer import MlflowLangchainTracer
from mlflow.langchain.model import _LangChainModelWrapper
from mlflow.tracing.constant import SpanAttributeKey
from mlflow.tracing.provider import trace_disabled

from tests.tracing.helper import get_traces

# The mock OpenAI endpoint simply echos the prompt back as the completion.
# So the expected output will be the prompt itself.
TEST_CONTENT = "What is MLflow?"
def create_openai_runnable(temperature=0.9):
    """Return a prompt | ChatOpenAI | StrOutputParser runnable used by tracing tests."""
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is {product}?",
    )
    llm = ChatOpenAI(temperature=temperature, stream_usage=True)
    return prompt | llm | StrOutputParser()


def create_retriever():
    """Build a FAISS retriever over the state-of-the-union fixture using fake embeddings."""
    loader = TextLoader("tests/scoring/state_of_the_union.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = FakeEmbeddings(size=5)
    db = FAISS.from_documents(docs, embeddings)
    return db.as_retriever()


def _validate_trace_json_serialization(trace):
    """Assert a trace round-trips losslessly through to_dict/from_dict and to_json/from_json."""
    trace_dict = trace.to_dict()
    trace_from_dict = Trace.from_dict(trace_dict)
    trace_json = trace.to_json()
    trace_from_json = Trace.from_json(trace_json)
    for loaded_trace in [trace_from_dict, trace_from_json]:
        assert trace.info == loaded_trace.info
        assert trace.data.request == loaded_trace.data.request
        assert trace.data.response == loaded_trace.data.response
        assert len(trace.data.spans) == len(loaded_trace.data.spans)
        for i in range(len(trace.data.spans)):
            # Compare spans field-by-field; span objects may not define __eq__.
            for attr in [
                "name",
                "request_id",
                "span_id",
                "start_time_ns",
                "end_time_ns",
                "parent_id",
                "status",
                "inputs",
                "outputs",
                "_trace_id",
                "attributes",
                "events",
            ]:
                assert getattr(trace.data.spans[i], attr) == getattr(
                    loaded_trace.data.spans[i], attr
                )


def test_llm_success():
    """A successful LLM run yields one OK LLM span with a new_token event and outputs."""
    callback = MlflowLangchainTracer()
    run_id = str(uuid.uuid4())
    callback.on_llm_start(
        {},
        ["test prompt"],
        run_id=run_id,
        name="test_llm",
    )

    callback.on_llm_new_token("test", run_id=run_id)

    callback.on_llm_end(LLMResult(generations=[[{"text": "generated text"}]]), run_id=run_id)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    llm_span = trace.data.spans[0]

    assert llm_span.name == "test_llm"

    assert llm_span.span_type == "LLM"
    assert llm_span.start_time_ns is not None
    assert llm_span.end_time_ns is not None
    assert llm_span.status == SpanStatus(SpanStatusCode.OK)
    assert llm_span.inputs == ["test prompt"]
    assert llm_span.outputs["choices"][0]["message"]["content"] == "generated text"
    assert llm_span.events[0].name == "new_token"

    _validate_trace_json_serialization(trace)


def test_llm_error():
    """An LLM error marks the span ERROR and records the exception as a span event."""
    callback = MlflowLangchainTracer()
    run_id = str(uuid.uuid4())
    callback.on_llm_start(
        {},
        ["test prompt"],
        run_id=run_id,
        name="test_llm",
    )
    mock_error = Exception("mock exception")
    callback.on_llm_error(error=mock_error, run_id=run_id)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    error_event = SpanEvent.from_exception(mock_error)
    assert len(trace.data.spans) == 1
    llm_span = trace.data.spans[0]
    assert llm_span.status.status_code == SpanStatusCode.ERROR
    assert llm_span.status.description == str(mock_error)
    assert llm_span.inputs == ["test prompt"]
    assert llm_span.outputs is None
    # timestamp is auto-generated when converting the error to event
    assert llm_span.events[0].name == error_event.name
    assert llm_span.events[0].attributes == error_event.attributes

    _validate_trace_json_serialization(trace)


def test_llm_internal_exception():
    """Ending an LLM run with an unknown run_id raises; the tracer is flushed regardless."""
    callback = MlflowLangchainTracer()
    run_id = str(uuid.uuid4())
    callback.on_llm_start(
        {},
        ["test prompt"],
        run_id=run_id,
        name="test_llm",
    )
    try:
        with pytest.raises(
            Exception,
            match="Span for run_id dummy not found.",
        ):
            callback.on_llm_end(LLMResult(generations=[[{"text": "generated"}]]), run_id="dummy")
    finally:
        callback.flush()


def test_chat_model():
    """A chat-model run produces a CHAT_MODEL span with role/content message inputs."""
    callback = MlflowLangchainTracer()
    run_id = str(uuid.uuid4())
    input_messages = [SystemMessage("system prompt"), HumanMessage("test prompt")]
    callback.on_chat_model_start(
        {},
        [input_messages],
        run_id=run_id,
        name="test_chat_model",
    )
    callback.on_llm_end(
        LLMResult(generations=[[{"text": "generated text"}]]),
        run_id=run_id,
    )

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    chat_model_span = trace.data.spans[0]
    assert chat_model_span.name == "test_chat_model"
    assert chat_model_span.span_type == "CHAT_MODEL"
    assert chat_model_span.status.status_code == SpanStatusCode.OK
    assert chat_model_span.inputs["messages"][0]["role"] == "system"
    assert chat_model_span.inputs["messages"][0]["content"] == "system prompt"
    assert chat_model_span.inputs["messages"][1]["role"] == "user"
    assert chat_model_span.inputs["messages"][1]["content"] == "test prompt"
    assert chat_model_span.outputs["choices"][0]["message"]["content"] == "generated text"


def test_chat_model_with_tool():
    """OpenAI-format tool definitions are stored verbatim on the CHAT_TOOLS attribute."""
    callback = MlflowLangchainTracer()
    run_id = str(uuid.uuid4())
    input_messages = [HumanMessage("test prompt")]
    # OpenAI tool format
    tool_definition = {
        "type": "function",
        "function": {
            "name": "GetWeather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "properties": {
                    "location": {
                        "description": "The city and state, e.g. San Francisco, CA",
                        "type": "string",
                    }
                },
                "required": ["location"],
                "type": "object",
            },
        },
    }
    callback.on_chat_model_start(
        {},
        [input_messages],
        run_id=run_id,
        name="test_chat_model",
        invocation_params={"tools": [tool_definition]},
    )
    callback.on_llm_end(
        LLMResult(generations=[[{"text": "generated text"}]]),
        run_id=run_id,
    )

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    chat_model_span = trace.data.spans[0]
    assert chat_model_span.status.status_code == SpanStatusCode.OK
    assert chat_model_span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == [tool_definition]


def test_chat_model_with_non_openai_tool():
    """Anthropic-style tool definitions are normalized into the OpenAI function schema."""
    callback = MlflowLangchainTracer()
    run_id = str(uuid.uuid4())
    input_messages = [HumanMessage("test prompt")]
    # Anthropic tool format
    tool_definition = {
        "name": "get_weather",
        "description": "Get the weather for a location.",
        "input_schema": {
            "properties": {
                "location": {
                    "description": "The city and state, e.g. San Francisco, CA",
                    "type": "string",
                }
            },
            "required": ["location"],
            "type": "object",
        },
    }
    callback.on_chat_model_start(
        {},
        [input_messages],
        run_id=run_id,
        name="test_chat_model",
        invocation_params={"tools": [tool_definition]},
    )
    callback.on_llm_end(
        LLMResult(generations=[[{"text": "generated text"}]]),
        run_id=run_id,
    )

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    chat_model_span = trace.data.spans[0]
    assert chat_model_span.status.status_code == SpanStatusCode.OK
    assert chat_model_span.get_attribute(SpanAttributeKey.CHAT_TOOLS) == [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the weather for a location.",
            },
        }
    ]


def test_retriever_success():
    """A retriever run records the query as inputs and MLflow documents as outputs."""
    callback = MlflowLangchainTracer()
    run_id = str(uuid.uuid4())
    callback.on_retriever_start(
        {},
        query="test query",
        run_id=run_id,
        name="test_retriever",
    )

    documents = [
        Document(
            page_content="document content 1",
            metadata={"chunk_id": "1", "doc_uri": "uri1"},
        ),
        Document(
            page_content="document content 2",
            metadata={"chunk_id": "2", "doc_uri": "uri2"},
        ),
    ]
    callback.on_retriever_end(documents, run_id=run_id)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    retriever_span = trace.data.spans[0]

    assert retriever_span.name == "test_retriever"
    assert retriever_span.span_type == "RETRIEVER"
    assert retriever_span.inputs == "test query"
    assert retriever_span.outputs == [
        MlflowDocument.from_langchain_document(doc).to_dict() for doc in documents
    ]
    assert retriever_span.start_time_ns is not None
    assert retriever_span.end_time_ns is not None
    assert retriever_span.status.status_code == SpanStatusCode.OK

    _validate_trace_json_serialization(trace)


def test_retriever_error():
    """A retriever error marks the span ERROR with the exception recorded as an event."""
    callback = MlflowLangchainTracer()
    run_id = str(uuid.uuid4())
    callback.on_retriever_start(
        {},
        query="test query",
        run_id=run_id,
        name="test_retriever",
    )
    mock_error = Exception("mock exception")
    callback.on_retriever_error(error=mock_error, run_id=run_id)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 1
    retriever_span = trace.data.spans[0]
    assert retriever_span.inputs == "test query"
    assert retriever_span.outputs is None
    error_event = SpanEvent.from_exception(mock_error)
    assert retriever_span.status.status_code == SpanStatusCode.ERROR
    assert retriever_span.events[0].name == error_event.name
    assert retriever_span.events[0].attributes == error_event.attributes

    _validate_trace_json_serialization(trace)


def test_retriever_internal_exception():
    """Ending a retriever run with an unknown run_id raises; the tracer is flushed regardless."""
    callback = MlflowLangchainTracer()
    run_id = str(uuid.uuid4())
    callback.on_retriever_start(
        {},
        query="test query",
        run_id=run_id,
        name="test_retriever",
    )

    try:
        with pytest.raises(
            Exception,
            match="Span for run_id dummy not found.",
        ):
            callback.on_retriever_end(
                [
                    Document(
                        page_content="document content 1",
                        metadata={"chunk_id": "1", "doc_uri": "uri1"},
                    )
                ],
                run_id="dummy",
            )
    finally:
        callback.flush()


def test_multiple_components():
    """A chain with nested LLM and retriever runs yields one trace of 5 correctly linked spans."""
    callback = MlflowLangchainTracer()
    chain_run_id = str(uuid.uuid4())
    callback.on_chain_start(
        {},
        inputs={"input": "test input"},
        run_id=chain_run_id,
        name="test_chain",
    )
    for i in range(2):
        llm_run_id = str(uuid.uuid4())
        retriever_run_id = str(uuid.uuid4())
        callback.on_llm_start(
            {},
            [f"test prompt {i}"],
            run_id=llm_run_id,
            name="test_llm",
            parent_run_id=chain_run_id,
        )
        callback.on_retriever_start(
            {},
            query=f"test query {i}",
            run_id=retriever_run_id,
            name="test_retriever",
            parent_run_id=llm_run_id,
        )
        callback.on_retriever_end(
            [
                Document(
                    page_content=f"document content {i}",
                    metadata={
                        "chunk_id": str(i),
                        "doc_uri": f"https://mock_uri.com/{i}",
                    },
                )
            ],
            run_id=retriever_run_id,
        )
        callback.on_llm_end(
            LLMResult(generations=[[{"text": f"generated text {i}"}]]),
            run_id=llm_run_id,
        )
    callback.on_chain_end(
        outputs={"output": "test output"},
        run_id=chain_run_id,
    )
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.data.spans) == 5
    chain_span = trace.data.spans[0]
    assert chain_span.start_time_ns is not None
    assert chain_span.end_time_ns is not None
    assert chain_span.name == "test_chain"
    assert chain_span.span_type == "CHAIN"
    assert chain_span.parent_id is None
    assert chain_span.status.status_code == SpanStatusCode.OK
    assert chain_span.inputs == {"input": "test input"}
    assert chain_span.outputs == {"output": "test output"}
    for i in range(2):
        llm_span = trace.data.spans[1 + i * 2]
        assert llm_span.inputs == [f"test prompt {i}"]
        assert llm_span.outputs["choices"][0]["message"]["content"] == f"generated text {i}"
        retriever_span = trace.data.spans[2 + i * 2]
        assert retriever_span.inputs == f"test query {i}"
        assert (
            retriever_span.outputs[0]
            == MlflowDocument(
                page_content=f"document content {i}",
                metadata={
                    "chunk_id": str(i),
                    "doc_uri": f"https://mock_uri.com/{i}",
                },
            ).to_dict()
        )

    _validate_trace_json_serialization(trace)


def test_tool_success():
    """Invoking a chain wrapped as a tool traces TOOL -> CHAIN -> prompt/LLM/parser spans."""
    callback = MlflowLangchainTracer()
    prompt = SystemMessagePromptTemplate.from_template("You are a nice assistant.") + "{question}"
    llm = ChatOpenAI()

    chain = prompt | llm | StrOutputParser()
    chain_tool = tool("chain_tool", chain)

    tool_input = {"question": "What up"}
    chain_tool.invoke(tool_input, config={"callbacks": [callback]})

    # str output is converted to _ChatResponse
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    spans = trace.data.spans
    assert len(spans) == 5

    # Tool
    tool_span = spans[0]
    assert tool_span.span_type == "TOOL"
    assert tool_span.inputs == tool_input
    assert tool_span.outputs is not None
    tool_span_id = tool_span.span_id

    # RunnableSequence
    runnable_sequence_span = spans[1]
    assert runnable_sequence_span.parent_id == tool_span_id
    assert runnable_sequence_span.span_type == "CHAIN"
    assert runnable_sequence_span.inputs == tool_input
    assert runnable_sequence_span.outputs is not None

    # PromptTemplate
    prompt_template_span = spans[2]
    assert prompt_template_span.span_type == "CHAIN"
    # LLM
    llm_span = spans[3]
    assert llm_span.span_type == "CHAT_MODEL"
    # StrOutputParser
    output_parser_span = spans[4]
    assert output_parser_span.span_type == "CHAIN"
    assert output_parser_span.outputs == [
        {"content": "You are a nice assistant.", "role": "system"},
        {"content": "What up", "role": "user"},
    ]

    _validate_trace_json_serialization(trace)


def test_tracer_thread_safe():
    """Concurrent chain runs sharing one tracer each produce their own single-span trace."""
    tracer = MlflowLangchainTracer()

    def worker_function(worker_id):
        # One root chain run per worker thread.
        chain_run_id = str(uuid.uuid4())
        tracer.on_chain_start(
            {}, {"input": "test input"}, run_id=chain_run_id, name=f"chain_{worker_id}"
        )
        # wait for a random time (0.5 ~ 1s) to simulate real-world scenario
        time.sleep(random.random() / 2 + 0.5)
        tracer.on_chain_end({"output": "test output"}, run_id=chain_run_id)

    with ThreadPoolExecutor(max_workers=10, thread_name_prefix="test-langchain-tracer") as executor:
        futures = [executor.submit(worker_function, i) for i in range(10)]
        for future in futures:
            future.result()

    traces = get_traces()
    assert len(traces) == 10
    assert all(len(trace.data.spans) == 1 for trace in traces)


def test_tracer_does_not_add_spans_to_trace_after_root_run_has_finished():
    """After the root run ends, replaying a callback must fail rather than extend the trace."""

    class FakeChatModel(SimpleChatModel):
        """Fake Chat Model wrapper for testing purposes."""

        def _call(self, messages: list[BaseMessage], **kwargs: Any) -> str:
            return TEST_CONTENT

        @property
        def _llm_type(self) -> str:
            return "fake chat model"

    run_id_for_on_chain_end = None

    class ExceptionCatchingTracer(MlflowLangchainTracer):
        # Captures the run_id seen by on_chain_end so it can be replayed after the
        # trace has been finalized.
        def on_chain_end(self, outputs, *, run_id, inputs=None, **kwargs):
            nonlocal run_id_for_on_chain_end
            run_id_for_on_chain_end = run_id
            super().on_chain_end(outputs, run_id=run_id, inputs=inputs, **kwargs)

    prompt = SystemMessagePromptTemplate.from_template("You are a nice assistant.") + "{question}"
    chain = prompt | FakeChatModel() | StrOutputParser()

    tracer = ExceptionCatchingTracer()

    chain.invoke(
        "What is MLflow?",
        config={"callbacks": [tracer]},
    )

    with pytest.raises(MlflowException, match="Span for run_id .* not found."):
        # After the chain is invoked, verify that the tracer no longer holds references to spans,
        # ensuring that the tracer does not add spans to the trace after the root run has finished
        tracer.on_chain_end({"output": "test output"}, run_id=run_id_for_on_chain_end, inputs=None)


def test_tracer_noop_when_tracing_disabled(monkeypatch):
    """When tracing is disabled, the tracer is a silent no-op: no traces, no warnings."""
    llm_chain = create_openai_runnable()
    model = _LangChainModelWrapper(llm_chain)

    @trace_disabled
    def _predict():
        return model._predict_with_callbacks(
            ["MLflow"],
            callback_handlers=[MlflowLangchainTracer()],
            convert_chat_responses=True,
        )

    mock_logger = MagicMock()
    monkeypatch.setattr(mlflow.tracking.client, "_logger", mock_logger)

    response = _predict()
    assert response is not None
    assert get_traces() == []
    # No warning should be issued
    mock_logger.warning.assert_not_called()


def test_tracer_with_manual_traces():
    # Validate if the callback works properly when outer and inner spans
    # are created by fluent APIs.
    llm = ChatOpenAI()
    prompt = PromptTemplate(
        input_variables=["color"],
        template="What is the complementary color of {color}?",
    )

    # Inner spans are created within RunnableLambda
    def foo(s: str):
        with mlflow.start_span(name="foo_inner") as span:
            span.set_inputs(s)
            s = s.replace("red", "blue")
            s = bar(s)
            span.set_outputs(s)
        return s

    @mlflow.trace
    def bar(s):
        return s.replace("blue", "green")

    chain = RunnableLambda(foo) | prompt | llm | StrOutputParser()

    @mlflow.trace(name="parent", span_type="SPECIAL")
    def run(message):
        return chain.invoke(message, config={"callbacks": [MlflowLangchainTracer()]})

    response = run("red")
    expected_response = '[{"role": "user", "content": "What is the complementary color of green?"}]'
    assert response == expected_response

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert trace is not None
    spans = trace.data.spans
    assert spans[0].name == "parent"
    assert spans[1].name == "RunnableSequence"
    assert spans[1].parent_id == spans[0].span_id
    assert spans[2].name == "foo"
    assert spans[2].parent_id == spans[1].span_id
    assert spans[3].name == "foo_inner"
    assert spans[3].parent_id == spans[2].span_id
    assert spans[4].name == "bar"
    assert spans[4].parent_id == spans[3].span_id
    assert spans[5].name == "PromptTemplate"
    assert spans[5].parent_id == spans[1].span_id


def test_serialize_invocation_params_success():
    """A pydantic response_format is replaced by its JSON schema; other params survive."""

    class DummyModel(pydantic.BaseModel):
        field: str

    callback = MlflowLangchainTracer()
    attributes = {"invocation_params": {"response_format": DummyModel, "other_param": "preserved"}}
    result = callback._serialize_invocation_params(attributes)
    expected_schema = DummyModel.model_json_schema()
    assert "invocation_params" in result
    assert "response_format" in result["invocation_params"]
    assert result["invocation_params"]["response_format"] == expected_schema
    assert result["invocation_params"]["other_param"] == "preserved"


def test_serialize_invocation_params_failure():
    """If schema generation raises, the raw response_format is kept instead of failing."""

    class FaultyModel(pydantic.BaseModel):
        field: str

        @classmethod
        def model_json_schema(cls):
            raise Exception("dummy failure")

    callback = MlflowLangchainTracer()
    attributes = {"invocation_params": {"response_format": FaultyModel, "other_param": "preserved"}}
    result = callback._serialize_invocation_params(attributes)
    assert result["invocation_params"]["response_format"] == FaultyModel
    assert result["invocation_params"]["other_param"] == "preserved"


def test_serialize_invocation_params_non_pydantic_response_format():
    """Non-pydantic response_format values pass through unchanged."""
    callback = MlflowLangchainTracer()
    test_cases = ["string_value", {"dict_key": "value"}, 123, ["list", "of", "items"], None]

    for test_value in test_cases:
        attributes = {
            "invocation_params": {"response_format": test_value, "other_param": "preserved"}
        }
        result = callback._serialize_invocation_params(attributes)
        assert result["invocation_params"]["response_format"] == test_value
        assert result["invocation_params"]["other_param"] == "preserved"


def test_serialize_invocation_params_no_invocation_params():
    """Attributes without invocation_params are returned as-is."""
    callback = MlflowLangchainTracer()
    attributes = {"other_key": "value"}
    result = callback._serialize_invocation_params(attributes)
    assert result == attributes


def test_serialize_invocation_params_none():
    """A None attributes mapping is returned as None."""
    callback = MlflowLangchainTracer()
    result = callback._serialize_invocation_params(None)
    assert result is None


@pytest.mark.asyncio
async def test_tracer_with_manual_traces_async():
    """Async chain invocation nests LangChain spans under a fluent-API parent span."""
    llm = ChatOpenAI()
    prompt = PromptTemplate(
        input_variables=["color"],
        template="What is the complementary color of {color}?",
    )

    @mlflow.trace
    def manual_transform(s: str):
        return s.replace("red", "blue")

    chain = RunnableLambda(manual_transform) | prompt | llm | StrOutputParser()

    @mlflow.trace(name="parent")
    async def run(message):
        # run_inline=True ensures proper context propagation in async scenarios
        tracer = MlflowLangchainTracer(run_inline=True)
        return await chain.ainvoke(message, config={"callbacks": [tracer]})

    response = await run("red")
    expected_response = '[{"role": "user", "content": "What is the complementary color of blue?"}]'
    assert response == expected_response

    traces = get_traces()
    assert len(traces) == 1

    trace = traces[0]
    spans = trace.data.spans
    assert spans[0].name == "parent"
    assert spans[1].name == "RunnableSequence"
    assert spans[1].parent_id == spans[0].span_id
    assert spans[2].name == "manual_transform"
    assert spans[2].parent_id == spans[1].span_id


@pytest.mark.parametrize(
    ("_type", "expected_provider"),
    [
        ("openai-chat", "openai"),
        ("anthropic-chat", "anthropic"),
        ("bedrock-chat", "bedrock"),
        ("openai", "openai"),
    ],
)
def test_chat_model_extracts_model_provider(_type, expected_provider):
    """The _type invocation param is mapped to the span's MODEL_PROVIDER attribute."""
    callback = MlflowLangchainTracer()
    run_id = str(uuid.uuid4())
    callback.on_chat_model_start(
        {},
        [[HumanMessage("test")]],
        run_id=run_id,
        name="test_chat_model",
        invocation_params={"model": "gpt-4", "_type": _type},
    )
    callback.on_llm_end(
        LLMResult(generations=[[{"text": "response"}]]),
        run_id=run_id,
    )

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    span = trace.data.spans[0]
    assert span.get_attribute(SpanAttributeKey.MODEL) == "gpt-4"
    assert span.get_attribute(SpanAttributeKey.MODEL_PROVIDER) == expected_provider
def test_chat_model_no_provider_when_type_missing():
    """Without a "_type" invocation param, no model-provider attribute is recorded."""
    tracer = MlflowLangchainTracer()
    chat_run_id = str(uuid.uuid4())
    tracer.on_chat_model_start(
        {},
        [[HumanMessage("test")]],
        run_id=chat_run_id,
        name="test_chat_model",
        invocation_params={"model": "gpt-4"},
    )
    tracer.on_llm_end(
        LLMResult(generations=[[{"text": "response"}]]),
        run_id=chat_run_id,
    )

    logged_trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    chat_span = logged_trace.data.spans[0]
    assert chat_span.get_attribute(SpanAttributeKey.MODEL) == "gpt-4"
    assert chat_span.get_attribute(SpanAttributeKey.MODEL_PROVIDER) is None


@pytest.mark.parametrize("run_tracer_inline", [True, False])
def test_tracer_run_inline_parameter(run_tracer_inline):
    """The run_inline constructor flag is stored verbatim on the tracer instance."""
    assert MlflowLangchainTracer(run_inline=run_tracer_inline).run_inline == run_tracer_inline