/ mlflow / langchain / output_parsers.py
output_parsers.py
from dataclasses import asdict
from typing import Any, AsyncIterator, Iterator, Union
from uuid import uuid4

from langchain_core.messages.base import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser

from mlflow.models.rag_signatures import (
    ChainCompletionChoice,
    Message,
    StringResponse,
)
from mlflow.models.rag_signatures import (
    ChatCompletionResponse as RagChatCompletionResponse,
)
from mlflow.types.agent import ChatAgentChunk, ChatAgentMessage, ChatAgentResponse
from mlflow.types.llm import (
    ChatChoice,
    ChatChoiceDelta,
    ChatChunkChoice,
    ChatCompletionChunk,
    ChatCompletionResponse,
    ChatMessage,
)
from mlflow.utils.annotations import deprecated
 25  from mlflow.utils.annotations import deprecated
 26  
 27  
 28  @deprecated("mlflow.langchain.output_parser.ChatCompletionOutputParser")
 29  class ChatCompletionsOutputParser(BaseTransformOutputParser[dict[str, Any]]):
 30      """
 31      OutputParser that wraps the string output into a dictionary representation of a
 32      :py:class:`ChatCompletionResponse`
 33      """
 34  
 35      @classmethod
 36      def is_lc_serializable(cls) -> bool:
 37          """Return whether this class is serializable."""
 38          return True
 39  
 40      @property
 41      def _type(self) -> str:
 42          """Return the output parser type for serialization."""
 43          return "mlflow_simplified_chat_completions"
 44  
 45      def parse(self, text: str) -> dict[str, Any]:
 46          return asdict(
 47              RagChatCompletionResponse(
 48                  choices=[ChainCompletionChoice(message=Message(role="assistant", content=text))],
 49                  object="chat.completion",
 50              )
 51          )
 52  
 53  
 54  class ChatCompletionOutputParser(BaseTransformOutputParser[str]):
 55      """
 56      OutputParser that wraps the string output into a dictionary representation of a
 57      :py:class:`ChatCompletionResponse` or :py:class:`ChatCompletionChunk`
 58      when streaming
 59      """
 60  
 61      @classmethod
 62      def is_lc_serializable(cls) -> bool:
 63          """Return whether this class is serializable."""
 64          return True
 65  
 66      @property
 67      def _type(self) -> str:
 68          """Return the output parser type for serialization."""
 69          return "mlflow_chat_completion"
 70  
 71      def parse(self, text: str) -> dict[str, Any]:
 72          """Returns the input text as a ChatCompletionResponse with no changes."""
 73          return ChatCompletionResponse(
 74              choices=[ChatChoice(message=ChatMessage(role="assistant", content=text))]
 75          ).to_dict()
 76  
 77      def transform(self, input: Iterator[BaseMessage], config, **kwargs) -> Iterator[dict[str, Any]]:
 78          """Returns a generator of ChatCompletionChunk objects"""
 79          for chunk in input:
 80              yield ChatCompletionChunk(
 81                  choices=[ChatChunkChoice(delta=ChatChoiceDelta(content=chunk.content))]
 82              ).to_dict()
 83  
 84      async def atransform(
 85          self,
 86          input: AsyncIterator[BaseMessage],
 87          config: Any,
 88          **kwargs: Any,
 89      ) -> AsyncIterator[ChatCompletionChunk]:
 90          async for chunk in input:
 91              yield ChatCompletionChunk(
 92                  choices=[ChatChunkChoice(delta=ChatChoiceDelta(content=chunk.content))]
 93              ).to_dict()
 94  
 95  
 96  @deprecated("mlflow.langchain.output_parser.ChatCompletionOutputParser")
 97  class StringResponseOutputParser(BaseTransformOutputParser[dict[str, Any]]):
 98      """
 99      OutputParser that wraps the string output into an dictionary representation of a
100      :py:class:`StringResponse`
101      """
102  
103      @classmethod
104      def is_lc_serializable(cls) -> bool:
105          """Return whether this class is serializable."""
106          return True
107  
108      @property
109      def _type(self) -> str:
110          """Return the output parser type for serialization."""
111          return "mlflow_simplified_str_object"
112  
113      def parse(self, text: str) -> dict[str, Any]:
114          return asdict(StringResponse(content=text))
115  
116  
class ChatAgentOutputParser(BaseTransformOutputParser[str]):
    """
    OutputParser that wraps the string output into a dictionary representation of a
    :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` or a
    :py:class:`ChatAgentChunk <mlflow.types.agent.ChatAgentChunk>` for easy interoperability.
    """

    # NOTE(review): the generic argument is ``str`` although ``parse``/``transform`` emit
    # dicts; it likely drives LangChain's ``OutputType``, so it is left unchanged.
    # NOTE(review): only ``transform`` is overridden here; async streaming presumably falls
    # back to the base-class ``atransform`` (which routes chunks through ``parse``) —
    # confirm that output shape is intended.

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    @property
    def _type(self) -> str:
        """Return the output parser type for serialization."""
        return "mlflow_chat_agent"

    def parse(self, text: str) -> dict[str, Any]:
        """
        Returns the output text as a dictionary representation of a
        :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>`.

        A fresh ``uuid4`` string is generated as the message id on every call, and
        ``None``-valued fields are dropped from the dump.
        """
        return ChatAgentResponse(
            messages=[ChatAgentMessage(content=text, role="assistant", id=str(uuid4()))]
        ).model_dump(exclude_none=True)

    def transform(
        self,
        input: Iterator[BaseMessage],
        config: Any = None,
        **kwargs: Any,
    ) -> Iterator[dict[str, Any]]:
        """
        Yield dictionary representations of
        :py:class:`ChatAgentChunk <mlflow.types.agent.ChatAgentChunk>` objects.

        Args:
            input: Streamed message chunks; each chunk's ``content`` and ``id`` are copied
                into the yielded chunk's ``delta``.
            config: Optional runnable config; not used by this method.
            kwargs: Not used by this method.

        Returns:
            A generator of ``ChatAgentChunk(...).model_dump(exclude_none=True)`` values.
            Chunks with empty (falsy) content are skipped.
        """
        for chunk in input:
            # Skip empty deltas rather than emitting content-less chunks.
            if chunk.content:
                yield ChatAgentChunk(
                    delta=ChatAgentMessage(content=chunk.content, role="assistant", id=chunk.id)
                ).model_dump(exclude_none=True)