llm_inference_utils.py
1 from __future__ import annotations 2 3 import time 4 import uuid 5 from typing import TYPE_CHECKING, Any 6 7 import numpy as np 8 import pandas as pd 9 10 from mlflow.exceptions import MlflowException 11 from mlflow.models import ModelSignature 12 from mlflow.protos.databricks_pb2 import BAD_REQUEST, INVALID_PARAMETER_VALUE 13 from mlflow.transformers.flavor_config import FlavorKey 14 from mlflow.types.llm import ( 15 CHAT_MODEL_INPUT_SCHEMA, 16 CHAT_MODEL_OUTPUT_SCHEMA, 17 COMPLETIONS_MODEL_INPUT_SCHEMA, 18 COMPLETIONS_MODEL_OUTPUT_SCHEMA, 19 EMBEDDING_MODEL_INPUT_SCHEMA, 20 EMBEDDING_MODEL_OUTPUT_SCHEMA, 21 ) 22 23 if TYPE_CHECKING: 24 import torch 25 26 _LLM_INFERENCE_TASK_KEY = "inference_task" 27 # The LLM inference task is saved as "task" in the metadata for forward compatibility with 28 # future Databricks Provisioned Throughput support of more model architectures for inference. 29 _METADATA_LLM_INFERENCE_TASK_KEY = "task" 30 31 _LLM_INFERENCE_TASK_PREFIX = "llm/v1" 32 _LLM_INFERENCE_TASK_COMPLETIONS = f"{_LLM_INFERENCE_TASK_PREFIX}/completions" 33 _LLM_INFERENCE_TASK_CHAT = f"{_LLM_INFERENCE_TASK_PREFIX}/chat" 34 _LLM_INFERENCE_TASK_EMBEDDING = f"{_LLM_INFERENCE_TASK_PREFIX}/embeddings" 35 36 _LLM_V1_EMBEDDING_INPUT_KEY = "input" 37 38 39 _LLM_INFERENCE_OBJECT_NAME = { 40 _LLM_INFERENCE_TASK_COMPLETIONS: "text_completion", 41 _LLM_INFERENCE_TASK_CHAT: "chat.completion", 42 } 43 44 _SUPPORTED_LLM_INFERENCE_TASK_TYPES_BY_PIPELINE_TASK = { 45 "text-generation": [_LLM_INFERENCE_TASK_COMPLETIONS, _LLM_INFERENCE_TASK_CHAT], 46 "feature-extraction": [_LLM_INFERENCE_TASK_EMBEDDING], 47 } 48 49 _SIGNATURE_FOR_LLM_INFERENCE_TASK = { 50 _LLM_INFERENCE_TASK_CHAT: ModelSignature( 51 inputs=CHAT_MODEL_INPUT_SCHEMA, outputs=CHAT_MODEL_OUTPUT_SCHEMA 52 ), 53 _LLM_INFERENCE_TASK_COMPLETIONS: ModelSignature( 54 inputs=COMPLETIONS_MODEL_INPUT_SCHEMA, outputs=COMPLETIONS_MODEL_OUTPUT_SCHEMA 55 ), 56 _LLM_INFERENCE_TASK_EMBEDDING: ModelSignature( 57 inputs=EMBEDDING_MODEL_INPUT_SCHEMA, outputs=EMBEDDING_MODEL_OUTPUT_SCHEMA 58 ), 59 } 60 61 _LLM_INFERENCE_TASK_TO_DATA_FIELD = { 62 _LLM_INFERENCE_TASK_CHAT: "messages", 63 _LLM_INFERENCE_TASK_COMPLETIONS: "prompt", 64 } 65 66 67 def infer_signature_from_llm_inference_task( 68 inference_task: str, signature: ModelSignature | None = None 69 ) -> ModelSignature: 70 """ 71 Infers the signature according to the MLflow inference task. 72 Raises exception if a signature is given. 73 """ 74 inferred_signature = _SIGNATURE_FOR_LLM_INFERENCE_TASK[inference_task] 75 76 if signature is not None and signature != inferred_signature: 77 raise MlflowException( 78 f"When `task` is specified as `{inference_task}`, the signature would " 79 "be set by MLflow. Please do not set the signature." 80 ) 81 return inferred_signature 82 83 84 def convert_messages_to_prompt(messages: list[dict[str, Any]], tokenizer) -> str: 85 """For the Chat inference task, apply chat template to messages to create prompt. 86 87 Args: 88 messages: List of message e.g. [{"role": user, "content": xxx}, ...] 89 tokenizer: The tokenizer object used for inference. 90 91 Returns: 92 The prompt string contains the messages. 93 """ 94 if not (isinstance(messages, list) and all(isinstance(msg, dict) for msg in messages)): 95 raise MlflowException( 96 f"Input messages should be list of dictionaries, but got: {type(messages)}.", 97 error_code=INVALID_PARAMETER_VALUE, 98 ) 99 100 try: 101 return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 102 except Exception as e: 103 raise MlflowException(f"Failed to apply chat template: {e}") 104 105 106 def preprocess_llm_inference_input( 107 data: pd.DataFrame | dict[str, Any], 108 params: dict[str, Any] | None = None, 109 flavor_config: dict[str, Any] | None = None, 110 ) -> tuple[list[Any], dict[str, Any]]: 111 """ 112 When a MLflow inference task is given, return updated `data` and `params` that 113 - Extract the parameters from the input data (from the first row if passed multiple rows) 114 - Replace OpenAI specific parameters with Hugging Face specific parameters, in particular 115 - `max_tokens` with `max_new_tokens` 116 - `stop` with `stopping_criteria` 117 118 Args: 119 data: Input data for the LLM inference task. Either a pandas DataFrame (after signature 120 enforcement) or a raw dictionary payload. 121 params: Optional dictionary of parameters. 122 flavor_config: Optional dictionary of flavor configuration. 123 """ 124 if isinstance(data, pd.DataFrame): 125 # Pandas convert None to np.nan internally, which is not preferred 126 data = data.replace(np.nan, None).to_dict(orient="list") 127 elif isinstance(data, dict): 128 # Convert single value to list for consistency with DataFrame 129 data = {k: [v] for k, v in data.items()} 130 else: 131 raise MlflowException( 132 "Input data for a Transformer model logged with `llm/v1/chat` or `llm/v1/completions`" 133 f"task is expected to be a pandas DataFrame or a dictionary, but got: {type(data)}.", 134 error_code=BAD_REQUEST, 135 ) 136 137 flavor_config = flavor_config or {} 138 params = params or {} 139 140 # Extract list of input data (prompt, messages) to LLM 141 task = flavor_config[_LLM_INFERENCE_TASK_KEY] 142 input_col = _LLM_INFERENCE_TASK_TO_DATA_FIELD.get(task) 143 if input_col not in data: 144 raise MlflowException( 145 f"Transformer model saved with `{task}` task excepts `{input_col}`" 146 "to be passed as input data.", 147 error_code=BAD_REQUEST, 148 ) 149 update_data = data.pop(input_col) 150 151 # The rest of fields in input payload should goes to params and override default ones 152 params_in_data = {k: v[0] for k, v in data.items() if v[0] is not None} 153 params = params | params_in_data 154 155 if max_tokens := params.pop("max_tokens", None): 156 params["max_new_tokens"] = max_tokens 157 if stop := params.pop("stop", None): 158 params["stopping_criteria"] = _get_stopping_criteria( 159 stop, 160 flavor_config.get(FlavorKey.MODEL_NAME), 161 ) 162 return update_data, params 163 164 165 def _get_stopping_criteria(stop: str | list[str] | None, model_name: str | None = None): 166 """Return a list of Hugging Face stopping criteria objects for the given stop sequences.""" 167 from transformers import AutoTokenizer, StoppingCriteria 168 169 if stop is None or model_name is None: 170 return None 171 172 if isinstance(stop, str): 173 stop = [stop] 174 175 # To tokenize the stop sequences for stopping criteria, we need to use the slow tokenizer 176 # for matching the actual tokens, according to https://github.com/huggingface/transformers/issues/27704 177 tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 178 179 def _get_slow_token_ids(seq: str): 180 return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(seq)) 181 182 # NB: We need to define this as an inner class to avoid importing 183 # transformers in the global scope that confuses autologging 184 class _StopSequenceMatchCriteria(StoppingCriteria): 185 def __init__(self, stop_sequence_ids): 186 self.stop_sequence_ids = stop_sequence_ids 187 188 def __call__( 189 self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs 190 ) -> bool: 191 last_ids = input_ids[:, -len(self.stop_sequence_ids) :].tolist() 192 return self.stop_sequence_ids in last_ids 193 194 stopping_criteria = [] 195 for stop_sequence in stop: 196 # Add stopping criteria for both with and without space, such as "stopword" and " stopword" 197 token_ids = _get_slow_token_ids(stop_sequence) 198 token_ids_with_space = _get_slow_token_ids(" " + stop_sequence) 199 stopping_criteria += [ 200 _StopSequenceMatchCriteria(token_ids), 201 _StopSequenceMatchCriteria(token_ids_with_space), 202 ] 203 204 return stopping_criteria 205 206 207 def postprocess_output_for_llm_inference_task( 208 data: list[str], 209 output_tensors: list[list[int]], 210 pipeline, 211 flavor_config, 212 model_config, 213 inference_task, 214 ): 215 """ 216 Wrap output data with usage information according to the MLflow inference task. 217 218 Example: 219 .. code-block:: python 220 data = ["How to learn Python in 3 weeks?"] 221 output_tensors = [ 222 [ 223 1128, 224 304, 225 ..., 226 29879, 227 ] 228 ] 229 output_dicts = postprocess_output_for_llm_inference_task(data, output_tensors, **kwargs) 230 231 assert output_dicts == [ 232 { 233 "id": "e4f3b3e3-3b3e-4b3e-8b3e-3b3e4b3e8b3e", 234 "object": "text_completion", 235 "created": 1707466970, 236 "model": "loaded_model_name", 237 "choices": [ 238 { 239 "index": 0, 240 "finish_reason": "length", 241 "text": "1. Start with a beginner's", 242 } 243 ], 244 "usage": {"prompt_tokens": 9, "completion_tokens": 10, "total_tokens": 19}, 245 } 246 ] 247 248 Args: 249 data: List of text input prompts. 250 output_tensors: List of output tensors that contain the generated tokens (including 251 the prompt tokens) corresponding to each input prompt. 252 pipeline: The pipeline object used for inference. 253 flavor_config: The flavor configuration dictionary for the model. 254 model_config: The model configuration dictionary used for inference. 255 inference_task: The MLflow inference task. 256 257 Returns: 258 List of dictionaries containing the output text and usage information for each input prompt. 259 """ 260 return [ 261 _get_output_and_usage_from_tensor( 262 input_data, output_tensor, pipeline, flavor_config, model_config, inference_task 263 ) 264 for input_data, output_tensor in zip(data, output_tensors) 265 ] 266 267 268 def _get_output_and_usage_from_tensor( 269 prompt: str, output_tensor: list[int], pipeline, flavor_config, model_config, inference_task 270 ): 271 """ 272 Decode the output tensor and return the output text and usage information as a dictionary 273 to make the output in OpenAI compatible format. 274 """ 275 usage = _get_token_usage(prompt, output_tensor, pipeline, model_config) 276 completions_text = _get_completions_text(prompt, output_tensor, pipeline) 277 finish_reason = _get_finish_reason( 278 usage["total_tokens"], usage["completion_tokens"], model_config 279 ) 280 281 output_dict = { 282 "id": str(uuid.uuid4()), 283 "object": _LLM_INFERENCE_OBJECT_NAME[inference_task], 284 "created": int(time.time()), 285 "model": flavor_config.get("source_model_name", ""), 286 "usage": usage, 287 } 288 289 completion_choice = { 290 "index": 0, 291 "finish_reason": finish_reason, 292 } 293 294 if inference_task == _LLM_INFERENCE_TASK_COMPLETIONS: 295 completion_choice["text"] = completions_text 296 elif inference_task == _LLM_INFERENCE_TASK_CHAT: 297 completion_choice["message"] = {"role": "assistant", "content": completions_text} 298 299 output_dict["choices"] = [completion_choice] 300 301 return output_dict 302 303 304 def _get_completions_text(prompt: str, output_tensor: list[int], pipeline): 305 """Decode generated text from output tensor and remove the input prompt.""" 306 generated_text = pipeline.tokenizer.decode( 307 output_tensor, 308 skip_special_tokens=True, 309 clean_up_tokenization_spaces=True, 310 ) 311 312 # In order to correctly remove the prompt tokens from the decoded tokens, 313 # we need to acquire the length of the prompt without special tokens 314 # NB: `pipeline.framework` was removed in transformers 5.x. Fall back to "pt" since 315 # MLflow only supports PyTorch for transformers pipelines. 316 prompt_ids_without_special_tokens = pipeline.tokenizer( 317 prompt, return_tensors=getattr(pipeline, "framework", "pt"), add_special_tokens=False 318 )["input_ids"][0] 319 320 prompt_length = len( 321 pipeline.tokenizer.decode( 322 prompt_ids_without_special_tokens, 323 skip_special_tokens=True, 324 clean_up_tokenization_spaces=True, 325 ) 326 ) 327 328 return generated_text[prompt_length:].lstrip() 329 330 331 def _get_token_usage(prompt: str, output_tensor: list[int], pipeline, model_config): 332 """Return the prompt tokens, completion tokens, and the total tokens as dict.""" 333 inputs = pipeline.tokenizer( 334 prompt, 335 return_tensors=getattr(pipeline, "framework", "pt"), 336 max_length=model_config.get("max_length", None), 337 add_special_tokens=False, 338 ) 339 340 prompt_tokens = inputs["input_ids"].shape[-1] 341 total_tokens = len(output_tensor) 342 completions_tokens = total_tokens - prompt_tokens 343 344 return { 345 "prompt_tokens": prompt_tokens, 346 "completion_tokens": completions_tokens, 347 "total_tokens": total_tokens, 348 } 349 350 351 def _get_finish_reason(total_tokens: int, completion_tokens: int, model_config): 352 """Determine the reason that the text generation finished.""" 353 finish_reason = "stop" 354 355 if total_tokens > model_config.get( 356 "max_length", float("inf") 357 ) or completion_tokens == model_config.get("max_new_tokens", float("inf")): 358 finish_reason = "length" 359 360 return finish_reason 361 362 363 def _get_default_task_for_llm_inference_task(llm_inference_task: str | None) -> str | None: 364 """ 365 Get corresponding original Transformers task for the given LLM inference task. 366 367 NB: This assumes there is only one original Transformers task for each LLM inference 368 task, which might not be true in the future. 369 """ 370 for task, llm_tasks in _SUPPORTED_LLM_INFERENCE_TASK_TYPES_BY_PIPELINE_TASK.items(): 371 if llm_inference_task in llm_tasks: 372 return task 373 return None 374 375 376 def preprocess_llm_embedding_params( 377 data: pd.DataFrame | dict[str, Any], 378 ) -> tuple[list[str], dict[str, Any]]: 379 """ 380 When `llm/v1/embeddings` task is given, extract the input data (with "input" key) and 381 parameters, and format the input data into the unified format for easier downstream handling. 382 383 The handling is more complicated than other LLM inference tasks because the embedding endpoint 384 accepts heterogeneous input - both string and list of strings as input. Also we don't enforce 385 the input schema always, so there are 4 possible input types: 386 (1) Pandas DataFrame with string column 387 (2) Pandas DataFrame with list of strings column 388 (3) Dictionary with string value 389 (4) Dictionary with list of strings value 390 In all cases, the returned input data will be a list of strings. 391 392 Args: 393 data: Input data for the embedding task. 394 395 Returns: 396 Tuple of input data and parameters dictionary. 397 """ 398 if isinstance(data, pd.DataFrame): 399 params = {} 400 for col in data.columns: 401 if col == _LLM_V1_EMBEDDING_INPUT_KEY: 402 input_data = data[col].to_list() 403 if isinstance(input_data[0], list): 404 input_data = input_data[0] 405 else: 406 params[col] = data[col].tolist()[0] 407 else: 408 # NB: Input schema is not enforced for the embedding task because of the heterogeneous 409 # input type, so we have to cast the input data into unified format here. 410 input_data = data.get(_LLM_V1_EMBEDDING_INPUT_KEY) 411 if isinstance(input, str): 412 input_data = [input_data] 413 params = {k: v for k, v in data.items() if k != _LLM_V1_EMBEDDING_INPUT_KEY} 414 415 return input_data, params 416 417 418 def postprocess_output_for_llm_v1_embedding_task( 419 input_prompts: list[str], 420 output_tensors: list[list[float]], 421 tokenizer, 422 ): 423 """ 424 Wrap output data with usage information. 425 426 Examples: 427 .. code-block:: python 428 input_prompt = ["hello world and hello mlflow"] 429 output_embedding = [0.47137904, 0.4669448, ..., 0.69726706] 430 output_dicts = postprocess_output_for_llm_v1_embedding_task( 431 input_prompt, output_embedding 432 ) 433 assert output_dicts == [ 434 { 435 "object": "list", 436 "data": [ 437 { 438 "object": "embedding", 439 "index": 0, 440 "embedding": [0.47137904, 0.4669448, ..., 0.69726706], 441 } 442 ], 443 "usage": {"prompt_tokens": 8, "total_tokens": 8}, 444 } 445 ] 446 447 Args: 448 input_prompts: text input prompts 449 output_tensors: List of output tensors that contain the generated embeddings 450 tokenizer: The tokenizer object used for inference. 451 452 Returns: 453 Dictionaries containing the output embedding and usage information for each 454 input prompt. 455 """ 456 prompt_tokens = sum(len(tokenizer(prompt)["input_ids"]) for prompt in input_prompts) 457 return { 458 "object": "list", 459 "data": [ 460 { 461 "object": "embedding", 462 "index": i, 463 "embedding": tensor, 464 } 465 for i, tensor in enumerate(output_tensors) 466 ], 467 "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}, 468 }