/ mlflow / models / rag_signatures.py
rag_signatures.py
  1  from dataclasses import dataclass, field
  2  
  3  from mlflow.models import ModelSignature
  4  from mlflow.types.schema import (
  5      Array,
  6      ColSpec,
  7      DataType,
  8      Object,
  9      Property,
 10      Schema,
 11  )
 12  from mlflow.utils.annotations import deprecated
 13  
 14  
 15  @deprecated("mlflow.types.llm.ChatMessage")
 16  @dataclass
 17  class Message:
 18      role: str = "user"  # "system", "user", or "assistant"
 19      content: str = "What is mlflow?"
 20  
 21  
 22  @deprecated("mlflow.types.llm.ChatCompletionRequest")
 23  @dataclass
 24  class ChatCompletionRequest:
 25      messages: list[Message] = field(default_factory=lambda: [Message()])
 26  
 27  
 28  @deprecated("mlflow.types.llm.ChatCompletionRequest")
 29  @dataclass
 30  class SplitChatMessagesRequest:
 31      query: str = "What is mlflow?"
 32      history: list[Message] | None = field(default_factory=list)
 33  
 34  
 35  @deprecated("mlflow.types.llm.ChatCompletionRequest")
 36  @dataclass
 37  class MultiturnChatRequest:
 38      query: str = "What is mlflow?"
 39      history: list[Message] | None = field(default_factory=list)
 40  
 41  
 42  @deprecated("mlflow.types.llm.ChatChoice")
 43  @dataclass
 44  class ChainCompletionChoice:
 45      index: int = 0
 46      message: Message = field(
 47          default_factory=lambda: Message(
 48              role="assistant",
 49              content="MLflow is an open source platform for the machine learning lifecycle.",
 50          )
 51      )
 52      finish_reason: str = "stop"
 53  
 54  
 55  @deprecated("mlflow.types.llm.ChatCompletionChunk")
 56  @dataclass
 57  class ChainCompletionChunk:
 58      index: int = 0
 59      delta: Message = field(
 60          default_factory=lambda: Message(
 61              role="assistant",
 62              content="MLflow is an open source platform for the machine learning lifecycle.",
 63          )
 64      )
 65      finish_reason: str = "stop"
 66  
 67  
 68  @deprecated("mlflow.types.llm.ChatCompletionResponse")
 69  @dataclass
 70  class ChatCompletionResponse:
 71      choices: list[ChainCompletionChoice] = field(default_factory=lambda: [ChainCompletionChoice()])
 72      object: str = "chat.completion"
 73      # TODO: support ChainCompletionChunk in the future
 74  
 75  
 76  @deprecated("mlflow.types.llm.ChatCompletionResponse")
 77  @dataclass
 78  class StringResponse:
 79      content: str = "MLflow is an open source platform for the machine learning lifecycle."
 80  
 81  
 82  CHAT_COMPLETION_REQUEST_SCHEMA = Schema([
 83      ColSpec(
 84          name="messages",
 85          type=Array(
 86              Object([
 87                  Property("role", DataType.string),
 88                  Property("content", DataType.string),
 89              ])
 90          ),
 91      ),
 92  ])
 93  
 94  CHAT_COMPLETION_RESPONSE_SCHEMA = Schema([
 95      ColSpec(
 96          name="choices",
 97          type=Array(
 98              Object([
 99                  Property("index", DataType.long),
100                  Property(
101                      "message",
102                      Object([
103                          Property("role", DataType.string),
104                          Property("content", DataType.string),
105                      ]),
106                  ),
107                  Property("finish_reason", DataType.string),
108              ])
109          ),
110      ),
111  ])
112  
# Lookup from LLM inference task name to the ModelSignature built from the
# request/response schemas above.
SIGNATURE_FOR_LLM_INFERENCE_TASK = {
    "llm/v1/chat": ModelSignature(
        outputs=CHAT_COMPLETION_RESPONSE_SCHEMA,
        inputs=CHAT_COMPLETION_REQUEST_SCHEMA,
    ),
}