# mlflow/openai/autolog.py
import importlib.metadata
import json
import logging
from collections.abc import AsyncIterator, Iterator
from typing import Any

from packaging.version import Version

import mlflow
from mlflow.entities import SpanType
from mlflow.entities.span import LiveSpan
from mlflow.entities.span_event import SpanEvent
from mlflow.entities.span_status import SpanStatusCode
from mlflow.exceptions import MlflowException
from mlflow.openai.constant import FLAVOR_NAME
from mlflow.openai.utils.chat_schema import set_span_chat_attributes
from mlflow.telemetry.events import AutologgingEvent
from mlflow.telemetry.track import _record_event
from mlflow.tracing.constant import (
    STREAM_CHUNK_EVENT_NAME_FORMAT,
    STREAM_CHUNK_EVENT_VALUE_KEY,
    SpanAttributeKey,
    TokenUsageKey,
    TraceMetadataKey,
)
from mlflow.tracing.distributed import _get_tracing_headers_from_span
from mlflow.tracing.fluent import start_span_no_context
from mlflow.tracing.trace_manager import InMemoryTraceManager
from mlflow.tracing.utils import TraceJSONEncoder
from mlflow.utils.autologging_utils import autologging_integration
from mlflow.utils.autologging_utils.config import AutoLoggingConfig
from mlflow.utils.autologging_utils.safety import safe_patch

_logger = logging.getLogger(__name__)


def autolog(
    disable=False,
    exclusive=False,
    disable_for_unsupported_versions=False,
    silent=False,
    log_traces=True,
    disable_openai_agent_tracer=True,
):
    """
    Enables (or disables) and configures autologging from OpenAI to MLflow.
    Raises :py:class:`MlflowException <mlflow.exceptions.MlflowException>`
    if the installed OpenAI version is below 1.0.

    Args:
        disable: If ``True``, disables the OpenAI autologging integration. If ``False``,
            enables the OpenAI autologging integration.
        exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
            If ``False``, autologged content is logged to the active fluent run,
            which may be user-created.
        disable_for_unsupported_versions: If ``True``, disable autologging for versions of
            OpenAI that have not been tested against this version of the MLflow
            client or are incompatible.
        silent: If ``True``, suppress all event logs and warnings from MLflow during OpenAI
            autologging. If ``False``, show all events and warnings during OpenAI
            autologging.
        log_traces: If ``True``, traces are logged for OpenAI models. If ``False``, no traces
            are collected during inference. Defaults to ``True``.
        disable_openai_agent_tracer: If ``True``, disable the OpenAI Agent SDK tracer. If
            ``False``, enable the OpenAI Agent SDK tracer. Defaults to ``True``.
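
    Example (a minimal usage sketch; the model name and prompt are illustrative):

    .. code-block:: python

        import openai

        import mlflow

        mlflow.openai.autolog()

        client = openai.OpenAI()
        client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Hello!"}],
        )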
    """
    if Version(importlib.metadata.version("openai")).major < 1:
        raise MlflowException("OpenAI autologging is only supported for openai >= 1.0.0")

    # This needs to be called before doing any safe-patching (otherwise safe-patch will be no-op).
    # TODO: since this implementation is inconsistent, explore a universal way to solve the issue.
    _autolog(
        disable=disable,
        exclusive=exclusive,
        disable_for_unsupported_versions=disable_for_unsupported_versions,
        silent=silent,
        log_traces=log_traces,
    )

    # Tracing OpenAI Agent SDK. This has to be done outside the function annotated with
    # `@autologging_integration` because the function is not executed when `disable=True`.
    try:
        from agents.run import AgentRunner

        from mlflow.openai._agent_tracer import _patched_agent_run

        # NB: OpenAI's built-in tracer does not capture inputs/outputs of the
        # root span, which is inconvenient. Therefore, we add a patch for the
        # runner.run() method instead.
        safe_patch(FLAVOR_NAME, AgentRunner, "run", _patched_agent_run)

        from mlflow.openai._agent_tracer import (
            add_mlflow_trace_processor,
            clear_trace_processors,
            remove_mlflow_trace_processor,
        )

        if disable or not log_traces:
            remove_mlflow_trace_processor()
        else:
            if disable_openai_agent_tracer:
                clear_trace_processors()
            add_mlflow_trace_processor()
    except ImportError:
        pass

    _record_event(
        AutologgingEvent, {"flavor": FLAVOR_NAME, "log_traces": log_traces, "disable": disable}
    )


# This is required by mlflow.autolog()
autolog.integration_name = FLAVOR_NAME


# NB: The @autologging_integration annotation must be applied here, and the callback injection
# needs to happen outside the annotated function. This is because the annotated function is NOT
# executed when disable=True is passed. This prevents us from removing our callback and patching
# when autologging is turned off.
@autologging_integration(FLAVOR_NAME)
def _autolog(
    disable=False,
    exclusive=False,
    disable_for_unsupported_versions=False,
    silent=False,
    log_traces=True,
):
    from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
    from openai.resources.chat.completions import Completions as ChatCompletions
    from openai.resources.completions import AsyncCompletions, Completions
    from openai.resources.embeddings import AsyncEmbeddings, Embeddings

    for task in (ChatCompletions, Completions, Embeddings):
        safe_patch(FLAVOR_NAME, task, "create", patched_call)

    if hasattr(ChatCompletions, "parse"):
        # In openai>=1.92.0, `ChatCompletions` has a `parse` method:
        # https://github.com/openai/openai-python/commit/0e358ed66b317038705fb38958a449d284f3cb88
        safe_patch(FLAVOR_NAME, ChatCompletions, "parse", patched_call)

    for task in (AsyncChatCompletions, AsyncCompletions, AsyncEmbeddings):
        safe_patch(FLAVOR_NAME, task, "create", async_patched_call)

    if hasattr(AsyncChatCompletions, "parse"):
        # In openai>=1.92.0, `AsyncChatCompletions` has a `parse` method:
        # https://github.com/openai/openai-python/commit/0e358ed66b317038705fb38958a449d284f3cb88
        safe_patch(FLAVOR_NAME, AsyncChatCompletions, "parse", async_patched_call)

    try:
        from openai.resources.images import AsyncImages, Images

        safe_patch(FLAVOR_NAME, Images, "generate", patched_call)
        safe_patch(FLAVOR_NAME, AsyncImages, "generate", async_patched_call)
    except ImportError:
        pass

    try:
        from openai.resources.beta.chat.completions import AsyncCompletions, Completions
    except ImportError:
        pass
    else:
        safe_patch(FLAVOR_NAME, Completions, "parse", patched_call)
        safe_patch(FLAVOR_NAME, AsyncCompletions, "parse", async_patched_call)

    try:
        from openai.resources.responses import AsyncResponses, Responses
    except ImportError:
        pass
    else:
        safe_patch(FLAVOR_NAME, Responses, "create", patched_call)
        safe_patch(FLAVOR_NAME, Responses, "parse", patched_call)
        safe_patch(FLAVOR_NAME, AsyncResponses, "create", async_patched_call)
        safe_patch(FLAVOR_NAME, AsyncResponses, "parse", async_patched_call)


def _get_span_type(task: type) -> str:
    from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
    from openai.resources.chat.completions import Completions as ChatCompletions
    from openai.resources.completions import AsyncCompletions, Completions
    from openai.resources.embeddings import AsyncEmbeddings, Embeddings

    span_type_mapping = {
        ChatCompletions: SpanType.CHAT_MODEL,
        AsyncChatCompletions: SpanType.CHAT_MODEL,
        Completions: SpanType.LLM,
        AsyncCompletions: SpanType.LLM,
        Embeddings: SpanType.EMBEDDING,
        AsyncEmbeddings: SpanType.EMBEDDING,
    }

    try:
        from openai.resources.images import AsyncImages, Images

        span_type_mapping[Images] = SpanType.TOOL
        span_type_mapping[AsyncImages] = SpanType.TOOL
    except ImportError:
        pass

    try:
        # Only available in openai>=1.40.0
        from openai.resources.beta.chat.completions import (
            AsyncCompletions as BetaAsyncChatCompletions,
        )
        from openai.resources.beta.chat.completions import Completions as BetaChatCompletions

        span_type_mapping[BetaChatCompletions] = SpanType.CHAT_MODEL
        span_type_mapping[BetaAsyncChatCompletions] = SpanType.CHAT_MODEL
    except ImportError:
        _logger.debug(
            "Failed to import `BetaChatCompletions` or `BetaAsyncChatCompletions`", exc_info=True
        )

    try:
        # The Responses API is only available in openai>=1.66.0
        from openai.resources.responses import AsyncResponses, Responses

        span_type_mapping[Responses] = SpanType.CHAT_MODEL
        span_type_mapping[AsyncResponses] = SpanType.CHAT_MODEL
    except ImportError:
        pass

    return span_type_mapping.get(task, SpanType.UNKNOWN)


def _try_parse_raw_response(response: Any) -> Any:
    """
    As documented at https://github.com/openai/openai-python/tree/52357cff50bee57ef442e94d78a0de38b4173fc2?tab=readme-ov-file#accessing-raw-response-data-eg-headers,
    a `LegacyAPIResponse` (https://github.com/openai/openai-python/blob/52357cff50bee57ef442e94d78a0de38b4173fc2/src/openai/_legacy_response.py#L45)
    object is returned when the `create` method is invoked with `with_raw_response`.
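
    Illustrative call (the model and prompt are placeholders):

    .. code-block:: python

        import openai

        client = openai.OpenAI()
        raw = client.chat.completions.with_raw_response.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "hi"}],
        )
        completion = raw.parse()  # ChatCompletion, or a Stream when stream=True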
    """
    try:
        from openai._legacy_response import LegacyAPIResponse
    except ImportError:
        _logger.debug("Failed to import `LegacyAPIResponse` from `openai._legacy_response`")
        return response
    if isinstance(response, LegacyAPIResponse):
        try:
            # `parse` returns either a `pydantic.BaseModel` or an `openai.Stream` object
            # depending on whether the request has a `stream` parameter set to `True`.
            return response.parse()
        except Exception as e:
            _logger.debug(f"Failed to parse {response} (type: {response.__class__}): {e}")

    return response


def _is_responses_api(original: Any) -> bool:
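    # e.g. `Responses.create.__qualname__` is "Responses.create", which splits into
    # ["Responses", "create"]; anything that does not split into exactly two parts
    # falls through to False.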
    match getattr(original, "__qualname__", "").split("."):
        case [class_name, _]:
            return class_name in {"Responses", "AsyncResponses"}
        case _:
            return False


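# NB: `safe_patch` invokes each patch with the original callable prepended, i.e. as
# `patched_call(original, self, *args, **kwargs)`; the patch is responsible for
# calling `original` itself and returning its result.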
def patched_call(original, self, *args, **kwargs):
    config = AutoLoggingConfig.init(flavor_name=mlflow.openai.FLAVOR_NAME)
    active_run = mlflow.active_run()
    run_id = active_run.info.run_id if active_run else None

    if config.log_traces:
        span = _start_span(self, kwargs, run_id)
        _inject_tracing_headers(kwargs, span)

    # Execute the original function
    try:
        raw_result = original(self, *args, **kwargs)
    except Exception as e:
        if config.log_traces:
            _end_span_on_exception(span, e)
        raise

    if config.log_traces:
        _end_span_on_success(span, kwargs, raw_result, is_responses_api=_is_responses_api(original))

    return raw_result


async def async_patched_call(original, self, *args, **kwargs):
    config = AutoLoggingConfig.init(flavor_name=mlflow.openai.FLAVOR_NAME)
    active_run = mlflow.active_run()
    run_id = active_run.info.run_id if active_run else None

    if config.log_traces:
        span = _start_span(self, kwargs, run_id)
        _inject_tracing_headers(kwargs, span)

    # Execute the original function
    try:
        raw_result = await original(self, *args, **kwargs)
    except Exception as e:
        if config.log_traces:
            _end_span_on_exception(span, e)
        raise

    if config.log_traces:
        _end_span_on_success(span, kwargs, raw_result, is_responses_api=_is_responses_api(original))

    return raw_result


def _start_span(
    instance: Any,
    inputs: dict[str, Any],
    run_id: str | None,
):
    span_type = _get_span_type(instance.__class__)
    # Record input parameters to attributes
    attributes = {k: v for k, v in inputs.items() if k not in ("messages", "input")}
    if span_type in (SpanType.CHAT_MODEL, SpanType.LLM):
        attributes[SpanAttributeKey.MESSAGE_FORMAT] = "openai"

    # If there is an active span, create a child span under it, otherwise create a new trace
    span = start_span_no_context(
        name=instance.__class__.__name__,
        span_type=span_type,
        inputs=inputs,
        attributes=attributes,
    )

    # Associate run ID to the trace manually, because if a new run is created by
    # autologging, it is not set as the active run thus not automatically
    # associated with the trace.
    if run_id is not None:
        tm = InMemoryTraceManager.get_instance()
        tm.set_trace_metadata(span.trace_id, TraceMetadataKey.SOURCE_RUN, run_id)

    return span


def _end_span_on_success(
    span: LiveSpan,
    inputs: dict[str, Any],
    raw_result: Any,
    is_responses_api: bool,
):
    from openai import AsyncStream, Stream

    result = _try_parse_raw_response(raw_result)

    if isinstance(result, Stream):
        # If the output is a stream, we add a hook to store the intermediate chunks
        # and then log the outputs as a single artifact when the stream ends
        def _stream_output_logging_hook(stream: Iterator) -> Iterator:
            # Guard against an empty stream, which would otherwise leave `chunk` unbound
            chunk = None
            output = []
            for i, chunk in enumerate(stream):
                _add_span_event(span, i, chunk)
                output.append(chunk)
                yield chunk
            _process_last_chunk(span, chunk, inputs, output, is_responses_api)

        result._iterator = _stream_output_logging_hook(result._iterator)
    elif isinstance(result, AsyncStream):

        async def _stream_output_logging_hook(stream: AsyncIterator) -> AsyncIterator:
            chunk = None
            output = []
            async for chunk in stream:
                _add_span_event(span, len(output), chunk)
                output.append(chunk)
                yield chunk
            _process_last_chunk(span, chunk, inputs, output, is_responses_api)

        result._iterator = _stream_output_logging_hook(result._iterator)
    else:
        try:
            set_span_chat_attributes(span, inputs, result)
            span.end(outputs=result)
        except Exception as e:
            _logger.warning(f"Encountered unexpected error when ending trace: {e}", exc_info=True)


def _process_last_chunk(
    span: LiveSpan,
    chunk: Any,
    inputs: dict[str, Any],
    output: list[Any],
    is_responses_api: bool,
) -> None:
    try:
        if _is_responses_final_event(chunk):
            output = chunk.response
        elif not output:
            output = None
        elif is_responses_api:
            output = _reconstruct_response_from_stream(output)
        elif output[0].object in ("text_completion", "chat.completion.chunk"):
            # Reconstruct a completion object from streaming chunks
            output = _reconstruct_completion_from_stream(output)
            # Set usage information on span if available
            if usage := getattr(chunk, "usage", None):
                usage_dict = {
                    TokenUsageKey.INPUT_TOKENS: usage.prompt_tokens,
                    TokenUsageKey.OUTPUT_TOKENS: usage.completion_tokens,
                    TokenUsageKey.TOTAL_TOKENS: usage.total_tokens,
                }

                # Extract cached tokens if available in the streaming chunk
                if details := getattr(usage, "prompt_tokens_details", None):
                    if (cached := getattr(details, "cached_tokens", None)) is not None:
                        usage_dict[TokenUsageKey.CACHE_READ_INPUT_TOKENS] = cached
                span.set_attribute(SpanAttributeKey.CHAT_USAGE, usage_dict)

        _end_span_on_success(span, inputs, output, is_responses_api)
    except Exception as e:
        _logger.warning(
            f"Encountered unexpected error while processing streaming chunks in the response: {e}"
        )


def _reconstruct_completion_from_stream(chunks: list[Any]) -> Any:
    """
    Reconstruct a completion object from streaming chunks.

    This preserves the structure and metadata that would be present in a non-streaming
    completion response, including ID, model, timestamps, usage, etc.
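
    For example (illustrative), chunks whose delta contents are "Hel" and "lo" are
    folded into a single ``ChatCompletion`` whose message content is "Hello",
    carrying the id, model, and usage of the last chunk.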
    """
    if chunks[0].object == "text_completion":
        # Handling for the deprecated Completions API. Keep the legacy behavior for now.
        def _extract_content(chunk: Any) -> str:
            if not chunk.choices:
                return ""
            return chunk.choices[0].text or ""

        return "".join(map(_extract_content, chunks))

    from openai.types.chat import ChatCompletion
    from openai.types.chat.chat_completion import Choice
    from openai.types.chat.chat_completion_message import ChatCompletionMessage

    # Build the base message
    def _extract_content(chunk: Any) -> str:
        if not chunk.choices:
            return ""
        content = chunk.choices[0].delta.content
        if content is None:
            return ""
        # Handle Databricks streaming format where content can be a list of content items
        # See https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/api-reference#content-item
        if isinstance(content, list):
            # Extract text from text items only.
            text_parts = [
                item["text"]
                for item in content
                if isinstance(item, dict) and item.get("type") == "text" and "text" in item
            ]
            return "".join(text_parts)
        return content

    message = ChatCompletionMessage(
        role="assistant", content="".join(map(_extract_content, chunks))
    )

    # Extract metadata from the last chunk
    last_chunk = chunks[-1]
    finish_reason = "stop"
    if choices := getattr(last_chunk, "choices", None):
        if chunk_choice := choices[0]:
            finish_reason = getattr(chunk_choice, "finish_reason", None) or finish_reason

    choice = Choice(index=0, message=message, finish_reason=finish_reason)

    # Build the completion object
    return ChatCompletion(
        id=last_chunk.id,
        choices=[choice],
        created=last_chunk.created,
        model=last_chunk.model,
        object="chat.completion",
        system_fingerprint=last_chunk.system_fingerprint,
        usage=last_chunk.usage,
    )


def _reconstruct_response_from_stream(chunks: list[Any]) -> Any:
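    """Rebuild a minimal Response object from the stream's output-item-done events.

    Only the completed output items are recovered here; request-level metadata
    is not reconstructed from the event stream.
    """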
    from openai.types.responses import ResponseOutputItemDoneEvent

    from mlflow.types.responses_helpers import Response

    output = [
        chunk.item.to_dict() for chunk in chunks if isinstance(chunk, ResponseOutputItemDoneEvent)
    ]

    return Response(output=output)


def _is_responses_final_event(chunk: Any) -> bool:
    try:
        from openai.types.responses import ResponseCompletedEvent

        return isinstance(chunk, ResponseCompletedEvent)
    except ImportError:
        return False


def _is_response_output_item_done_event(chunk: Any) -> bool:
    try:
        from openai.types.responses import ResponseOutputItemDoneEvent

        return isinstance(chunk, ResponseOutputItemDoneEvent)
    except ImportError:
        return False


def _inject_tracing_headers(kwargs: dict[str, Any], span: LiveSpan):
    try:
        if tracing_headers := _get_tracing_headers_from_span(span):
            existing = kwargs.get("extra_headers") or {}
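            # Dict union keeps caller-provided headers on key collisions (the right
            # operand wins), e.g. {"a": "1"} | {"a": "2"} == {"a": "2"}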
            kwargs["extra_headers"] = tracing_headers | existing
    except Exception:
        _logger.debug("Failed to inject tracing headers", exc_info=True)


def _end_span_on_exception(span: LiveSpan, e: Exception):
    try:
        span.add_event(SpanEvent.from_exception(e))
        span.end(status=SpanStatusCode.ERROR)
    except Exception as inner_e:
        _logger.warning(f"Encountered unexpected error when ending trace: {inner_e}")


def _add_span_event(span: LiveSpan, index: int, chunk: Any):
    span.add_event(
        SpanEvent(
            name=STREAM_CHUNK_EVENT_NAME_FORMAT.format(index=index),
            # OpenTelemetry span events only support str-str key-value pairs for attributes
            attributes={STREAM_CHUNK_EVENT_VALUE_KEY: json.dumps(chunk, cls=TraceJSONEncoder)},
        )
    )