# rag_signatures.py
from dataclasses import dataclass, field

from mlflow.models import ModelSignature
from mlflow.types.schema import (
    Array,
    ColSpec,
    DataType,
    Object,
    Property,
    Schema,
)
from mlflow.utils.annotations import deprecated

# Sample text used as the default values of the dataclasses below.
_SAMPLE_QUERY = "What is mlflow?"
_SAMPLE_ANSWER = "MLflow is an open source platform for the machine learning lifecycle."


@deprecated("mlflow.types.llm.ChatMessage")
@dataclass
class Message:
    """One chat turn: the speaker ``role`` and the text ``content``."""

    role: str = "user"  # one of "system", "user", or "assistant"
    content: str = _SAMPLE_QUERY


def _assistant_sample_message():
    """Build a new assistant ``Message`` whose content is the sample answer.

    Passed as a ``default_factory`` so each dataclass instance receives its
    own ``Message`` object instead of sharing one.
    """
    return Message(role="assistant", content=_SAMPLE_ANSWER)


def _single_default_message():
    """Build a new one-element list holding a default ``Message``."""
    return [Message()]


@deprecated("mlflow.types.llm.ChatCompletionRequest")
@dataclass
class ChatCompletionRequest:
    """Chat request given as a list of ``Message`` objects."""

    messages: list[Message] = field(default_factory=_single_default_message)


@deprecated("mlflow.types.llm.ChatCompletionRequest")
@dataclass
class SplitChatMessagesRequest:
    """Chat request with a ``query`` string and a separate ``history`` list.

    ``history`` defaults to a new empty list per instance.
    """

    query: str = _SAMPLE_QUERY
    history: list[Message] | None = field(default_factory=list)


@deprecated("mlflow.types.llm.ChatCompletionRequest")
@dataclass
class MultiturnChatRequest:
    """Chat request with a ``query`` string and a separate ``history`` list.

    Declares the same fields and defaults as ``SplitChatMessagesRequest``.
    """

    query: str = _SAMPLE_QUERY
    history: list[Message] | None = field(default_factory=list)


@deprecated("mlflow.types.llm.ChatChoice")
@dataclass
class ChainCompletionChoice:
    """One completion choice: its ``index``, the reply ``message`` and ``finish_reason``."""

    index: int = 0
    message: Message = field(default_factory=_assistant_sample_message)
    finish_reason: str = "stop"


@deprecated("mlflow.types.llm.ChatCompletionChunk")
@dataclass
class ChainCompletionChunk:
    """One completion chunk: its ``index``, a ``delta`` message and ``finish_reason``."""

    index: int = 0
    delta: Message = field(default_factory=_assistant_sample_message)
    finish_reason: str = "stop"


def _single_default_choice():
    """Build a new one-element list holding a default ``ChainCompletionChoice``."""
    return [ChainCompletionChoice()]


@deprecated("mlflow.types.llm.ChatCompletionResponse")
@dataclass
class ChatCompletionResponse:
    """Chat response holding a list of ``choices`` and an ``object`` type tag."""

    choices: list[ChainCompletionChoice] = field(default_factory=_single_default_choice)
    object: str = "chat.completion"
    # TODO: support ChainCompletionChunk in the future
@deprecated("mlflow.types.llm.ChatCompletionResponse")
@dataclass
class StringResponse:
    """Response consisting only of a plain ``content`` string."""

    content: str = "MLflow is an open source platform for the machine learning lifecycle."


def _role_content_object():
    """Return a new schema ``Object`` with string ``role`` and ``content`` properties.

    A fresh instance is built on each call so the request and response
    schemas do not share the same ``Object``.
    """
    return Object([
        Property("role", DataType.string),
        Property("content", DataType.string),
    ])


# Input schema: a single ``messages`` column holding an array of role/content objects.
CHAT_COMPLETION_REQUEST_SCHEMA = Schema([
    ColSpec(type=Array(_role_content_object()), name="messages"),
])

# Output schema: a single ``choices`` column holding an array of objects, each
# with a ``long`` index, a role/content ``message`` and a string ``finish_reason``.
CHAT_COMPLETION_RESPONSE_SCHEMA = Schema([
    ColSpec(
        type=Array(
            Object([
                Property("index", DataType.long),
                Property("message", _role_content_object()),
                Property("finish_reason", DataType.string),
            ])
        ),
        name="choices",
    ),
])

# Maps an LLM inference task name to the ModelSignature built from the schemas above.
SIGNATURE_FOR_LLM_INFERENCE_TASK = {
    "llm/v1/chat": ModelSignature(
        inputs=CHAT_COMPLETION_REQUEST_SCHEMA,
        outputs=CHAT_COMPLETION_RESPONSE_SCHEMA,
    ),
}