autolog.py
import importlib.metadata
import json
import logging
from typing import Any, AsyncIterator, Iterator

from packaging.version import Version

import mlflow
from mlflow.entities import SpanType
from mlflow.entities.span import LiveSpan
from mlflow.entities.span_event import SpanEvent
from mlflow.entities.span_status import SpanStatusCode
from mlflow.exceptions import MlflowException
from mlflow.openai.constant import FLAVOR_NAME
from mlflow.openai.utils.chat_schema import set_span_chat_attributes
from mlflow.telemetry.events import AutologgingEvent
from mlflow.telemetry.track import _record_event
from mlflow.tracing.constant import (
    STREAM_CHUNK_EVENT_NAME_FORMAT,
    STREAM_CHUNK_EVENT_VALUE_KEY,
    SpanAttributeKey,
    TokenUsageKey,
    TraceMetadataKey,
)
from mlflow.tracing.distributed import _get_tracing_headers_from_span
from mlflow.tracing.fluent import start_span_no_context
from mlflow.tracing.trace_manager import InMemoryTraceManager
from mlflow.tracing.utils import TraceJSONEncoder
from mlflow.utils.autologging_utils import autologging_integration
from mlflow.utils.autologging_utils.config import AutoLoggingConfig
from mlflow.utils.autologging_utils.safety import safe_patch

_logger = logging.getLogger(__name__)


def autolog(
    disable=False,
    exclusive=False,
    disable_for_unsupported_versions=False,
    silent=False,
    log_traces=True,
    disable_openai_agent_tracer=True,
):
    """
    Enables (or disables) and configures autologging from OpenAI to MLflow.
    Raises :py:class:`MlflowException <mlflow.exceptions.MlflowException>`
    if the OpenAI version is < 1.0.

    Args:
        disable: If ``True``, disables the OpenAI autologging integration. If ``False``,
            enables the OpenAI autologging integration.
        exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
            If ``False``, autologged content is logged to the active fluent run,
            which may be user-created.
        disable_for_unsupported_versions: If ``True``, disable autologging for versions of
            OpenAI that have not been tested against this version of the MLflow
            client or are incompatible.
        silent: If ``True``, suppress all event logs and warnings from MLflow during OpenAI
            autologging. If ``False``, show all events and warnings during OpenAI
            autologging.
        log_traces: If ``True``, traces are logged for OpenAI models. If ``False``, no traces
            are collected during inference. Defaults to ``True``.
        disable_openai_agent_tracer: If ``True``, disable the OpenAI Agents SDK tracer. If
            ``False``, enable the OpenAI Agents SDK tracer. Defaults to ``True``.
    """
    if Version(importlib.metadata.version("openai")).major < 1:
        raise MlflowException("OpenAI autologging is only supported for openai >= 1.0.0")

    # This needs to be called before doing any safe-patching (otherwise safe-patch will be
    # a no-op).
    # TODO: since this implementation is inconsistent, explore a universal way to solve the issue.
    _autolog(
        disable=disable,
        exclusive=exclusive,
        disable_for_unsupported_versions=disable_for_unsupported_versions,
        silent=silent,
        log_traces=log_traces,
    )

    # Tracing the OpenAI Agents SDK. This has to be done outside the function annotated with
    # `@autologging_integration` because that function is not executed when `disable=True`.
    try:
        from agents.run import AgentRunner

        from mlflow.openai._agent_tracer import _patched_agent_run

        # NB: OpenAI's built-in tracer does not capture inputs/outputs of the
        # root span, which is inconvenient. Therefore, we add a patch for the
        # runner's run() method instead.
        safe_patch(FLAVOR_NAME, AgentRunner, "run", _patched_agent_run)

        from mlflow.openai._agent_tracer import (
            add_mlflow_trace_processor,
            clear_trace_processors,
            remove_mlflow_trace_processor,
        )

        if disable or not log_traces:
            remove_mlflow_trace_processor()
        else:
            if disable_openai_agent_tracer:
                clear_trace_processors()
            add_mlflow_trace_processor()
    except ImportError:
        pass

    _record_event(
        AutologgingEvent, {"flavor": FLAVOR_NAME, "log_traces": log_traces, "disable": disable}
    )


# This is required by mlflow.autolog()
autolog.integration_name = FLAVOR_NAME


# NB: The @autologging_integration annotation must be applied here, and the callback injection
# needs to happen outside the annotated function. This is because the annotated function is NOT
# executed when disable=True is passed. This prevents us from removing our callback and patching
# when autologging is turned off.
@autologging_integration(FLAVOR_NAME)
def _autolog(
    disable=False,
    exclusive=False,
    disable_for_unsupported_versions=False,
    silent=False,
    log_traces=True,
):
    from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
    from openai.resources.chat.completions import Completions as ChatCompletions
    from openai.resources.completions import AsyncCompletions, Completions
    from openai.resources.embeddings import AsyncEmbeddings, Embeddings

    for task in (ChatCompletions, Completions, Embeddings):
        safe_patch(FLAVOR_NAME, task, "create", patched_call)

    if hasattr(ChatCompletions, "parse"):
        # In openai>=1.92.0, `ChatCompletions` has a `parse` method:
        # https://github.com/openai/openai-python/commit/0e358ed66b317038705fb38958a449d284f3cb88
        safe_patch(FLAVOR_NAME, ChatCompletions, "parse", patched_call)

    for task in (AsyncChatCompletions, AsyncCompletions, AsyncEmbeddings):
        safe_patch(FLAVOR_NAME, task, "create", async_patched_call)

    try:
        from openai.resources.images import AsyncImages, Images

        safe_patch(FLAVOR_NAME, Images, "generate", patched_call)
        safe_patch(FLAVOR_NAME, AsyncImages, "generate", async_patched_call)
    except ImportError:
        pass

    if hasattr(AsyncChatCompletions, "parse"):
        # In openai>=1.92.0, `AsyncChatCompletions` has a `parse` method:
        # https://github.com/openai/openai-python/commit/0e358ed66b317038705fb38958a449d284f3cb88
        safe_patch(FLAVOR_NAME, AsyncChatCompletions, "parse", async_patched_call)

    try:
        from openai.resources.beta.chat.completions import AsyncCompletions, Completions
    except ImportError:
        pass
    else:
        safe_patch(FLAVOR_NAME, Completions, "parse", patched_call)
        safe_patch(FLAVOR_NAME, AsyncCompletions, "parse", async_patched_call)

    try:
        from openai.resources.responses import AsyncResponses, Responses
    except ImportError:
        pass
    else:
        safe_patch(FLAVOR_NAME, Responses, "create", patched_call)
        safe_patch(FLAVOR_NAME, AsyncResponses, "create", async_patched_call)
        safe_patch(FLAVOR_NAME, AsyncResponses, "parse", async_patched_call)
        safe_patch(FLAVOR_NAME, Responses, "parse", patched_call)
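

# Illustrative usage sketch (not part of the integration, never called by MLflow):
# once `mlflow.openai.autolog()` has been called, plain OpenAI client calls are traced
# automatically. The `OpenAI()` client configuration and the model name below are
# assumptions for illustration only.
def _example_autolog_usage():  # hypothetical helper
    from openai import OpenAI

    mlflow.openai.autolog()  # patches the create/parse methods via safe_patch
    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
    # This call is intercepted by `patched_call` below, which opens a span, runs the
    # original method, and logs inputs/outputs to the current trace destination.
    return client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "Hello!"}],
    )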


def _get_span_type(task: type) -> str:
    from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
    from openai.resources.chat.completions import Completions as ChatCompletions
    from openai.resources.completions import AsyncCompletions, Completions
    from openai.resources.embeddings import AsyncEmbeddings, Embeddings

    span_type_mapping = {
        ChatCompletions: SpanType.CHAT_MODEL,
        AsyncChatCompletions: SpanType.CHAT_MODEL,
        Completions: SpanType.LLM,
        AsyncCompletions: SpanType.LLM,
        Embeddings: SpanType.EMBEDDING,
        AsyncEmbeddings: SpanType.EMBEDDING,
    }

    try:
        from openai.resources.images import AsyncImages, Images

        span_type_mapping[Images] = SpanType.TOOL
        span_type_mapping[AsyncImages] = SpanType.TOOL
    except ImportError:
        pass

    try:
        # Only available in openai>=1.40.0
        from openai.resources.beta.chat.completions import (
            AsyncCompletions as BetaAsyncChatCompletions,
        )
        from openai.resources.beta.chat.completions import Completions as BetaChatCompletions

        span_type_mapping[BetaChatCompletions] = SpanType.CHAT_MODEL
        span_type_mapping[BetaAsyncChatCompletions] = SpanType.CHAT_MODEL
    except ImportError:
        _logger.debug(
            "Failed to import `BetaChatCompletions` or `BetaAsyncChatCompletions`", exc_info=True
        )

    try:
        # The Responses API is only available in openai>=1.66.0
        from openai.resources.responses import AsyncResponses, Responses

        span_type_mapping[Responses] = SpanType.CHAT_MODEL
        span_type_mapping[AsyncResponses] = SpanType.CHAT_MODEL
    except ImportError:
        pass

    return span_type_mapping.get(task, SpanType.UNKNOWN)


def _try_parse_raw_response(response: Any) -> Any:
    """
    As documented at https://github.com/openai/openai-python/tree/52357cff50bee57ef442e94d78a0de38b4173fc2?tab=readme-ov-file#accessing-raw-response-data-eg-headers,
    a `LegacyAPIResponse` (https://github.com/openai/openai-python/blob/52357cff50bee57ef442e94d78a0de38b4173fc2/src/openai/_legacy_response.py#L45)
    object is returned when the `create` method is invoked via `with_raw_response`.
    """
    try:
        from openai._legacy_response import LegacyAPIResponse
    except ImportError:
        _logger.debug("Failed to import `LegacyAPIResponse` from `openai._legacy_response`")
        return response
    if isinstance(response, LegacyAPIResponse):
        try:
            # `parse` returns either a `pydantic.BaseModel` or an `openai.Stream` object
            # depending on whether the request has a `stream` parameter set to `True`.
            return response.parse()
        except Exception as e:
            _logger.debug(f"Failed to parse {response} (type: {response.__class__}): {e}")

    return response
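

# Illustrative sketch (never called by MLflow): `with_raw_response` wraps the return
# value in a `LegacyAPIResponse`, which is why autologging runs results through
# `_try_parse_raw_response` before logging outputs. The client setup and model name
# here are assumptions for illustration only.
def _example_raw_response_parsing():  # hypothetical helper
    from openai import OpenAI

    client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
    raw = client.chat.completions.with_raw_response.create(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "Hi"}],
    )
    # `raw` exposes HTTP-level data such as headers; `.parse()` recovers the typed
    # `ChatCompletion` object that autologging records as the span output.
    completion = _try_parse_raw_response(raw)
    return raw.headers, completion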


def _is_responses_api(original: Any) -> bool:
    match getattr(original, "__qualname__", "").split("."):
        case [class_name, _]:
            return class_name in {"Responses", "AsyncResponses"}
        case _:
            return False


def patched_call(original, self, *args, **kwargs):
    config = AutoLoggingConfig.init(flavor_name=mlflow.openai.FLAVOR_NAME)
    active_run = mlflow.active_run()
    run_id = active_run.info.run_id if active_run else None

    if config.log_traces:
        span = _start_span(self, kwargs, run_id)
        _inject_tracing_headers(kwargs, span)

    # Execute the original function
    try:
        raw_result = original(self, *args, **kwargs)
    except Exception as e:
        if config.log_traces:
            _end_span_on_exception(span, e)
        raise

    if config.log_traces:
        _end_span_on_success(
            span, kwargs, raw_result, is_responses_api=_is_responses_api(original)
        )

    return raw_result


async def async_patched_call(original, self, *args, **kwargs):
    config = AutoLoggingConfig.init(flavor_name=mlflow.openai.FLAVOR_NAME)
    active_run = mlflow.active_run()
    run_id = active_run.info.run_id if active_run else None

    if config.log_traces:
        span = _start_span(self, kwargs, run_id)
        _inject_tracing_headers(kwargs, span)

    # Execute the original function
    try:
        raw_result = await original(self, *args, **kwargs)
    except Exception as e:
        if config.log_traces:
            _end_span_on_exception(span, e)
        raise

    if config.log_traces:
        _end_span_on_success(
            span, kwargs, raw_result, is_responses_api=_is_responses_api(original)
        )

    return raw_result
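

# Minimal sketch of the patching convention (an assumption-labeled illustration, not
# MLflow's `safe_patch` implementation): the patch function receives the unpatched
# `original` callable as its first argument, followed by the bound instance and the
# caller's arguments, mirroring `patched_call` above. All names here are hypothetical.
def _example_safe_patch_convention():  # hypothetical helper, never called by MLflow
    class Client:
        def create(self, **kwargs):
            return {"echo": kwargs}

    original = Client.create

    def patched(orig, self, *args, **kwargs):
        # Pre-call hook would go here (e.g., start a span), then delegate.
        result = orig(self, *args, **kwargs)
        # Post-call hook would go here (e.g., end the span with outputs).
        return result

    # Naive rebinding; the real `safe_patch` adds exception safety and respects
    # the autologging configuration.
    Client.create = lambda self, *a, **kw: patched(original, self, *a, **kw)
    return Client().create(model="illustrative")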


def _start_span(
    instance: Any,
    inputs: dict[str, Any],
    run_id: str,
):
    span_type = _get_span_type(instance.__class__)
    # Record input parameters to attributes
    attributes = {k: v for k, v in inputs.items() if k not in ("messages", "input")}
    if span_type in (SpanType.CHAT_MODEL, SpanType.LLM):
        attributes[SpanAttributeKey.MESSAGE_FORMAT] = "openai"

    # If there is an active span, create a child span under it; otherwise create a new trace
    span = start_span_no_context(
        name=instance.__class__.__name__,
        span_type=span_type,
        inputs=inputs,
        attributes=attributes,
    )

    # Associate the run ID with the trace manually, because if a new run is created by
    # autologging, it is not set as the active run and thus is not automatically
    # associated with the trace.
    if run_id is not None:
        tm = InMemoryTraceManager.get_instance()
        tm.set_trace_metadata(span.trace_id, TraceMetadataKey.SOURCE_RUN, run_id)

    return span


def _end_span_on_success(
    span: LiveSpan,
    inputs: dict[str, Any],
    raw_result: Any,
    is_responses_api: bool,
):
    from openai import AsyncStream, Stream

    result = _try_parse_raw_response(raw_result)

    if isinstance(result, Stream):
        # If the output is a stream, we add a hook to store the intermediate chunks
        # and then log the outputs as a single artifact when the stream ends
        def _stream_output_logging_hook(stream: Iterator) -> Iterator:
            output = []
            for i, chunk in enumerate(stream):
                _add_span_event(span, i, chunk)
                output.append(chunk)
                yield chunk
            _process_last_chunk(span, chunk, inputs, output, is_responses_api)

        result._iterator = _stream_output_logging_hook(result._iterator)
    elif isinstance(result, AsyncStream):

        async def _stream_output_logging_hook(stream: AsyncIterator) -> AsyncIterator:
            output = []
            async for chunk in stream:
                _add_span_event(span, len(output), chunk)
                output.append(chunk)
                yield chunk
            _process_last_chunk(span, chunk, inputs, output, is_responses_api)

        result._iterator = _stream_output_logging_hook(result._iterator)
    else:
        try:
            set_span_chat_attributes(span, inputs, result)
            span.end(outputs=result)
        except Exception as e:
            _logger.warning(f"Encountered unexpected error when ending trace: {e}", exc_info=True)


def _process_last_chunk(
    span: LiveSpan,
    chunk: Any,
    inputs: dict[str, Any],
    output: list[Any],
    is_responses_api: bool,
) -> None:
    try:
        if _is_responses_final_event(chunk):
            output = chunk.response
        elif not output:
            output = None
        elif is_responses_api:
            output = _reconstruct_response_from_stream(output)
        elif output[0].object in ["text_completion", "chat.completion.chunk"]:
            # Reconstruct a completion object from streaming chunks
            output = _reconstruct_completion_from_stream(output)
        # Set usage information on the span if available
        if usage := getattr(chunk, "usage", None):
            usage_dict = {
                TokenUsageKey.INPUT_TOKENS: usage.prompt_tokens,
                TokenUsageKey.OUTPUT_TOKENS: usage.completion_tokens,
                TokenUsageKey.TOTAL_TOKENS: usage.total_tokens,
            }

            # Extract cached tokens if available in the streaming chunk
            if details := getattr(usage, "prompt_tokens_details", None):
                if (cached := getattr(details, "cached_tokens", None)) is not None:
                    usage_dict[TokenUsageKey.CACHE_READ_INPUT_TOKENS] = cached
            span.set_attribute(SpanAttributeKey.CHAT_USAGE, usage_dict)

        _end_span_on_success(span, inputs, output, is_responses_api)
    except Exception as e:
        _logger.warning(
            f"Encountered unexpected error when autologging processes the chunks in response: {e}"
        )
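

# Minimal sketch of the stream-wrapping pattern used in `_end_span_on_success`
# (assumption: chunk payloads are plain strings here; real chunks are pydantic
# models). Wrapping the iterator lets autologging observe every chunk without
# consuming the stream on the caller's behalf.
def _example_stream_wrapping():  # hypothetical helper, never called by MLflow
    def fake_stream():
        yield from ["Hel", "lo", "!"]

    collected = []

    def wrapper(stream):
        for chunk in stream:
            collected.append(chunk)  # record each chunk (cf. `_add_span_event`)
            yield chunk  # pass it through unchanged to the caller
        # Runs only once the caller exhausts the stream (cf. `_process_last_chunk`)
        collected.append("".join(collected))

    consumed = list(wrapper(fake_stream()))
    return consumed, collected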


def _reconstruct_completion_from_stream(chunks: list[Any]) -> Any:
    """
    Reconstruct a completion object from streaming chunks.

    This preserves the structure and metadata that would be present in a non-streaming
    completion response, including ID, model, timestamps, usage, etc.
    """
    if chunks[0].object == "text_completion":
        # Handling for the deprecated Completions API. Keep the legacy behavior for now.
        def _extract_content(chunk: Any) -> str:
            if not chunk.choices:
                return ""
            return chunk.choices[0].text or ""

        return "".join(map(_extract_content, chunks))

    from openai.types.chat import ChatCompletion
    from openai.types.chat.chat_completion import Choice
    from openai.types.chat.chat_completion_message import ChatCompletionMessage

    # Build the base message
    def _extract_content(chunk: Any) -> str:
        if not chunk.choices:
            return ""
        content = chunk.choices[0].delta.content
        if content is None:
            return ""
        # Handle the Databricks streaming format, where content can be a list of content items.
        # See https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/api-reference#content-item
        if isinstance(content, list):
            # Extract text from text items only.
            text_parts = [
                item["text"]
                for item in content
                if isinstance(item, dict) and item.get("type") == "text" and "text" in item
            ]
            return "".join(text_parts)
        return content

    message = ChatCompletionMessage(
        role="assistant", content="".join(map(_extract_content, chunks))
    )

    # Extract metadata from the last chunk
    last_chunk = chunks[-1]
    finish_reason = "stop"
    if choices := getattr(last_chunk, "choices", None):
        if chunk_choice := choices[0]:
            finish_reason = getattr(chunk_choice, "finish_reason", None) or finish_reason

    choice = Choice(index=0, message=message, finish_reason=finish_reason)

    # Build the completion object
    return ChatCompletion(
        id=last_chunk.id,
        choices=[choice],
        created=last_chunk.created,
        model=last_chunk.model,
        object="chat.completion",
        system_fingerprint=last_chunk.system_fingerprint,
        usage=last_chunk.usage,
    )


def _reconstruct_response_from_stream(chunks: list[Any]) -> Any:
    from openai.types.responses import ResponseOutputItemDoneEvent

    from mlflow.types.responses_helpers import Response

    output = [
        chunk.item.to_dict() for chunk in chunks if isinstance(chunk, ResponseOutputItemDoneEvent)
    ]

    return Response(output=output)


def _is_responses_final_event(chunk: Any) -> bool:
    try:
        from openai.types.responses import ResponseCompletedEvent

        return isinstance(chunk, ResponseCompletedEvent)
    except ImportError:
        return False


def _is_response_output_item_done_event(chunk: Any) -> bool:
    try:
        from openai.types.responses import ResponseOutputItemDoneEvent

        return isinstance(chunk, ResponseOutputItemDoneEvent)
    except ImportError:
        return False


def _inject_tracing_headers(kwargs: dict[str, Any], span: LiveSpan):
    try:
        if tracing_headers := _get_tracing_headers_from_span(span):
            existing = kwargs.get("extra_headers") or {}
            kwargs["extra_headers"] = tracing_headers | existing
    except Exception:
        _logger.debug("Failed to inject tracing headers", exc_info=True)


def _end_span_on_exception(span: LiveSpan, e: Exception):
    try:
        span.add_event(SpanEvent.from_exception(e))
        span.end(status=SpanStatusCode.ERROR)
    except Exception as inner_e:
        _logger.warning(f"Encountered unexpected error when ending trace: {inner_e}")


def _add_span_event(span: LiveSpan, index: int, chunk: Any):
    span.add_event(
        SpanEvent(
            name=STREAM_CHUNK_EVENT_NAME_FORMAT.format(index=index),
            # OpenTelemetry SpanEvent only supports str-str key-value pairs for attributes
            attributes={STREAM_CHUNK_EVENT_VALUE_KEY: json.dumps(chunk, cls=TraceJSONEncoder)},
        )
    )
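

# Illustrative end-to-end sketch (hypothetical, never called by MLflow): feeding
# synthetic `ChatCompletionChunk` objects through `_reconstruct_completion_from_stream`
# yields a single `ChatCompletion`, mirroring what `_process_last_chunk` does when a
# stream ends. All field values below are made up for illustration.
def _example_reconstruct_from_chunks():  # hypothetical helper
    from openai.types.chat import ChatCompletionChunk
    from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
    from openai.types.chat.chat_completion_chunk import ChoiceDelta

    chunks = [
        ChatCompletionChunk(
            id="chatcmpl-123",  # illustrative ID
            choices=[ChunkChoice(index=0, delta=ChoiceDelta(content=text), finish_reason=fr)],
            created=1700000000,  # illustrative timestamp
            model="gpt-4o-mini",  # illustrative model name
            object="chat.completion.chunk",
        )
        for text, fr in [("Hel", None), ("lo!", "stop")]
    ]
    completion = _reconstruct_completion_from_stream(chunks)
    assert completion.choices[0].message.content == "Hello!"
    return completion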