mlflow/transformers/llm_inference_utils.py
from __future__ import annotations

import time
import uuid
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd

from mlflow.exceptions import MlflowException
from mlflow.models import ModelSignature
from mlflow.protos.databricks_pb2 import BAD_REQUEST, INVALID_PARAMETER_VALUE
from mlflow.transformers.flavor_config import FlavorKey
from mlflow.types.llm import (
    CHAT_MODEL_INPUT_SCHEMA,
    CHAT_MODEL_OUTPUT_SCHEMA,
    COMPLETIONS_MODEL_INPUT_SCHEMA,
    COMPLETIONS_MODEL_OUTPUT_SCHEMA,
    EMBEDDING_MODEL_INPUT_SCHEMA,
    EMBEDDING_MODEL_OUTPUT_SCHEMA,
)

if TYPE_CHECKING:
    import torch

_LLM_INFERENCE_TASK_KEY = "inference_task"
# The LLM inference task is saved as "task" in the metadata for forward compatibility with
# future Databricks Provisioned Throughput support of more model architectures for inference.
_METADATA_LLM_INFERENCE_TASK_KEY = "task"

_LLM_INFERENCE_TASK_PREFIX = "llm/v1"
_LLM_INFERENCE_TASK_COMPLETIONS = f"{_LLM_INFERENCE_TASK_PREFIX}/completions"
_LLM_INFERENCE_TASK_CHAT = f"{_LLM_INFERENCE_TASK_PREFIX}/chat"
_LLM_INFERENCE_TASK_EMBEDDING = f"{_LLM_INFERENCE_TASK_PREFIX}/embeddings"

_LLM_V1_EMBEDDING_INPUT_KEY = "input"


_LLM_INFERENCE_OBJECT_NAME = {
    _LLM_INFERENCE_TASK_COMPLETIONS: "text_completion",
    _LLM_INFERENCE_TASK_CHAT: "chat.completion",
}

_SUPPORTED_LLM_INFERENCE_TASK_TYPES_BY_PIPELINE_TASK = {
    "text-generation": [_LLM_INFERENCE_TASK_COMPLETIONS, _LLM_INFERENCE_TASK_CHAT],
    "feature-extraction": [_LLM_INFERENCE_TASK_EMBEDDING],
}

_SIGNATURE_FOR_LLM_INFERENCE_TASK = {
    _LLM_INFERENCE_TASK_CHAT: ModelSignature(
        inputs=CHAT_MODEL_INPUT_SCHEMA, outputs=CHAT_MODEL_OUTPUT_SCHEMA
    ),
    _LLM_INFERENCE_TASK_COMPLETIONS: ModelSignature(
        inputs=COMPLETIONS_MODEL_INPUT_SCHEMA, outputs=COMPLETIONS_MODEL_OUTPUT_SCHEMA
    ),
    _LLM_INFERENCE_TASK_EMBEDDING: ModelSignature(
        inputs=EMBEDDING_MODEL_INPUT_SCHEMA, outputs=EMBEDDING_MODEL_OUTPUT_SCHEMA
    ),
}

_LLM_INFERENCE_TASK_TO_DATA_FIELD = {
    _LLM_INFERENCE_TASK_CHAT: "messages",
    _LLM_INFERENCE_TASK_COMPLETIONS: "prompt",
}


def infer_signature_from_llm_inference_task(
    inference_task: str, signature: ModelSignature | None = None
) -> ModelSignature:
 70      """
 71      Infers the signature according to the MLflow inference task.
 72      Raises exception if a signature is given.
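
    Example:
        .. code-block:: python
            # Illustrative: the chat task maps to the predefined chat input/output schemas.
            signature = infer_signature_from_llm_inference_task("llm/v1/chat")
            assert signature.inputs == CHAT_MODEL_INPUT_SCHEMA
            assert signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA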
 73      """
 74      inferred_signature = _SIGNATURE_FOR_LLM_INFERENCE_TASK[inference_task]
 75  
 76      if signature is not None and signature != inferred_signature:
 77          raise MlflowException(
 78              f"When `task` is specified as `{inference_task}`, the signature would "
 79              "be set by MLflow. Please do not set the signature."
 80          )
 81      return inferred_signature
 82  
 83  
 84  def convert_messages_to_prompt(messages: list[dict[str, Any]], tokenizer) -> str:
 85      """For the Chat inference task, apply chat template to messages to create prompt.
 86  
 87      Args:
 88          messages: List of message e.g. [{"role": user, "content": xxx}, ...]
 89          tokenizer: The tokenizer object used for inference.
 90  
 91      Returns:
 92          The prompt string contains the messages.
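
    Example:
        .. code-block:: python
            # Illustrative only: the exact prompt text depends on the tokenizer's chat template.
            messages = [{"role": "user", "content": "Hello"}]
            prompt = convert_messages_to_prompt(messages, tokenizer)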
 93      """
 94      if not (isinstance(messages, list) and all(isinstance(msg, dict) for msg in messages)):
 95          raise MlflowException(
 96              f"Input messages should be list of dictionaries, but got: {type(messages)}.",
 97              error_code=INVALID_PARAMETER_VALUE,
 98          )
 99  
100      try:
101          return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
102      except Exception as e:
103          raise MlflowException(f"Failed to apply chat template: {e}")
104  
105  
106  def preprocess_llm_inference_input(
107      data: pd.DataFrame | dict[str, Any],
108      params: dict[str, Any] | None = None,
109      flavor_config: dict[str, Any] | None = None,
110  ) -> tuple[list[Any], dict[str, Any]]:
111      """
112      When a MLflow inference task is given, return updated `data` and `params` that
113      - Extract the parameters from the input data (from the first row if passed multiple rows)
114      - Replace OpenAI specific parameters with Hugging Face specific parameters, in particular
115        - `max_tokens` with `max_new_tokens`
116        - `stop` with `stopping_criteria`
117  
118      Args:
119          data: Input data for the LLM inference task. Either a pandas DataFrame (after signature
120              enforcement) or a raw dictionary payload.
121          params: Optional dictionary of parameters.
122          flavor_config: Optional dictionary of flavor configuration.
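
    Example:
        .. code-block:: python
            # Illustrative payload; the prompt and parameter values are placeholders.
            data = {"prompt": "Hello", "temperature": 0.5, "max_tokens": 10}
            flavor_config = {"inference_task": "llm/v1/completions"}
            prompts, params = preprocess_llm_inference_input(data, flavor_config=flavor_config)
            assert prompts == ["Hello"]
            assert params == {"temperature": 0.5, "max_new_tokens": 10}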
123      """
124      if isinstance(data, pd.DataFrame):
125          # Pandas convert None to np.nan internally, which is not preferred
126          data = data.replace(np.nan, None).to_dict(orient="list")
127      elif isinstance(data, dict):
128          # Convert single value to list for consistency with DataFrame
129          data = {k: [v] for k, v in data.items()}
130      else:
131          raise MlflowException(
132              "Input data for a Transformer model logged with `llm/v1/chat` or `llm/v1/completions`"
133              f"task is expected to be a pandas DataFrame or a dictionary, but got: {type(data)}.",
134              error_code=BAD_REQUEST,
135          )
136  
137      flavor_config = flavor_config or {}
138      params = params or {}
139  
140      # Extract list of input data (prompt, messages) to LLM
141      task = flavor_config[_LLM_INFERENCE_TASK_KEY]
142      input_col = _LLM_INFERENCE_TASK_TO_DATA_FIELD.get(task)
143      if input_col not in data:
144          raise MlflowException(
145              f"Transformer model saved with `{task}` task excepts `{input_col}`"
146              "to be passed as input data.",
147              error_code=BAD_REQUEST,
148          )
149      update_data = data.pop(input_col)
150  
151      # The rest of fields in input payload should goes to params and override default ones
152      params_in_data = {k: v[0] for k, v in data.items() if v[0] is not None}
153      params = params | params_in_data
154  
155      if max_tokens := params.pop("max_tokens", None):
156          params["max_new_tokens"] = max_tokens
157      if stop := params.pop("stop", None):
158          params["stopping_criteria"] = _get_stopping_criteria(
159              stop,
160              flavor_config.get(FlavorKey.MODEL_NAME),
161          )
162      return update_data, params
163  
164  
165  def _get_stopping_criteria(stop: str | list[str] | None, model_name: str | None = None):
166      """Return a list of Hugging Face stopping criteria objects for the given stop sequences."""
    from transformers import AutoTokenizer, StoppingCriteria

    if stop is None or model_name is None:
        return None

    if isinstance(stop, str):
        stop = [stop]

    # To tokenize the stop sequences for the stopping criteria, we need to use the slow tokenizer
    # so the token ids match the actual generated tokens; see
    # https://github.com/huggingface/transformers/issues/27704
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    def _get_slow_token_ids(seq: str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(seq))

    # NB: We need to define this as an inner class to avoid importing
    # transformers in the global scope, which confuses autologging
    class _StopSequenceMatchCriteria(StoppingCriteria):
        def __init__(self, stop_sequence_ids):
            self.stop_sequence_ids = stop_sequence_ids

        def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
        ) -> bool:
            last_ids = input_ids[:, -len(self.stop_sequence_ids) :].tolist()
            return self.stop_sequence_ids in last_ids

    stopping_criteria = []
    for stop_sequence in stop:
        # Add stopping criteria for both with and without space, such as "stopword" and " stopword"
        token_ids = _get_slow_token_ids(stop_sequence)
        token_ids_with_space = _get_slow_token_ids(" " + stop_sequence)
        stopping_criteria += [
            _StopSequenceMatchCriteria(token_ids),
            _StopSequenceMatchCriteria(token_ids_with_space),
        ]

    return stopping_criteria


def postprocess_output_for_llm_inference_task(
    data: list[str],
    output_tensors: list[list[int]],
    pipeline,
    flavor_config,
    model_config,
    inference_task,
):
    """
    Wrap output data with usage information according to the MLflow inference task.

    Example:
        .. code-block:: python
            data = ["How to learn Python in 3 weeks?"]
            output_tensors = [
                [
                    1128,
                    304,
                    ...,
                    29879,
                ]
            ]
            output_dicts = postprocess_output_for_llm_inference_task(data, output_tensors, **kwargs)

            assert output_dicts == [
                {
                    "id": "e4f3b3e3-3b3e-4b3e-8b3e-3b3e4b3e8b3e",
                    "object": "text_completion",
                    "created": 1707466970,
                    "model": "loaded_model_name",
                    "choices": [
                        {
                            "index": 0,
                            "finish_reason": "length",
                            "text": "1. Start with a beginner's",
                        }
                    ],
                    "usage": {"prompt_tokens": 9, "completion_tokens": 10, "total_tokens": 19},
                }
            ]

    Args:
        data: List of text input prompts.
        output_tensors: List of output tensors that contain the generated tokens (including
            the prompt tokens) corresponding to each input prompt.
        pipeline: The pipeline object used for inference.
        flavor_config: The flavor configuration dictionary for the model.
        model_config: The model configuration dictionary used for inference.
        inference_task: The MLflow inference task.

    Returns:
        List of dictionaries containing the output text and usage information for each input prompt.
    """
    return [
        _get_output_and_usage_from_tensor(
            input_data, output_tensor, pipeline, flavor_config, model_config, inference_task
        )
        for input_data, output_tensor in zip(data, output_tensors)
    ]


def _get_output_and_usage_from_tensor(
    prompt: str, output_tensor: list[int], pipeline, flavor_config, model_config, inference_task
):
    """
    Decode the output tensor and return the output text and usage information as a dictionary
    in an OpenAI-compatible format.
    """
    usage = _get_token_usage(prompt, output_tensor, pipeline, model_config)
    completions_text = _get_completions_text(prompt, output_tensor, pipeline)
    finish_reason = _get_finish_reason(
        usage["total_tokens"], usage["completion_tokens"], model_config
    )

    output_dict = {
        "id": str(uuid.uuid4()),
        "object": _LLM_INFERENCE_OBJECT_NAME[inference_task],
        "created": int(time.time()),
        "model": flavor_config.get("source_model_name", ""),
        "usage": usage,
    }

    completion_choice = {
        "index": 0,
        "finish_reason": finish_reason,
    }

    if inference_task == _LLM_INFERENCE_TASK_COMPLETIONS:
        completion_choice["text"] = completions_text
    elif inference_task == _LLM_INFERENCE_TASK_CHAT:
        completion_choice["message"] = {"role": "assistant", "content": completions_text}

    output_dict["choices"] = [completion_choice]

    return output_dict


def _get_completions_text(prompt: str, output_tensor: list[int], pipeline):
    """Decode generated text from output tensor and remove the input prompt."""
    generated_text = pipeline.tokenizer.decode(
        output_tensor,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    # In order to correctly remove the prompt tokens from the decoded tokens,
    # we need to acquire the length of the prompt without special tokens
    # NB: `pipeline.framework` was removed in transformers 5.x. Fall back to "pt" since
    # MLflow only supports PyTorch for transformers pipelines.
    prompt_ids_without_special_tokens = pipeline.tokenizer(
        prompt, return_tensors=getattr(pipeline, "framework", "pt"), add_special_tokens=False
    )["input_ids"][0]

    prompt_length = len(
        pipeline.tokenizer.decode(
            prompt_ids_without_special_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    )

    return generated_text[prompt_length:].lstrip()


def _get_token_usage(prompt: str, output_tensor: list[int], pipeline, model_config):
    """Return the prompt tokens, completion tokens, and the total tokens as dict."""
    inputs = pipeline.tokenizer(
        prompt,
        return_tensors=getattr(pipeline, "framework", "pt"),
        max_length=model_config.get("max_length", None),
        add_special_tokens=False,
    )

    prompt_tokens = inputs["input_ids"].shape[-1]
    total_tokens = len(output_tensor)
    completions_tokens = total_tokens - prompt_tokens

    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completions_tokens,
        "total_tokens": total_tokens,
    }


def _get_finish_reason(total_tokens: int, completion_tokens: int, model_config):
    """Determine the reason that the text generation finished."""
    finish_reason = "stop"

    if total_tokens > model_config.get(
        "max_length", float("inf")
    ) or completion_tokens == model_config.get("max_new_tokens", float("inf")):
        finish_reason = "length"

    return finish_reason


def _get_default_task_for_llm_inference_task(llm_inference_task: str | None) -> str | None:
    """
    Get the corresponding original Transformers task for the given LLM inference task.

    NB: This assumes there is only one original Transformers task for each LLM inference
      task, which might not be true in the future.
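
    Example:
        .. code-block:: python
            # Illustrative: "llm/v1/chat" and "llm/v1/completions" map to "text-generation",
            # while "llm/v1/embeddings" maps to "feature-extraction".
            assert _get_default_task_for_llm_inference_task("llm/v1/chat") == "text-generation"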
369      """
370      for task, llm_tasks in _SUPPORTED_LLM_INFERENCE_TASK_TYPES_BY_PIPELINE_TASK.items():
371          if llm_inference_task in llm_tasks:
372              return task
373      return None
374  
375  
376  def preprocess_llm_embedding_params(
377      data: pd.DataFrame | dict[str, Any],
378  ) -> tuple[list[str], dict[str, Any]]:
379      """
380      When `llm/v1/embeddings` task is given, extract the input data (with "input" key) and
381      parameters, and format the input data into the unified format for easier downstream handling.
382  
383      The handling is more complicated than other LLM inference tasks because the embedding endpoint
384      accepts heterogeneous input - both string and list of strings as input. Also we don't enforce
385      the input schema always, so there are 4 possible input types:
386        (1)  Pandas DataFrame with string column
387        (2)  Pandas DataFrame with list of strings column
388        (3)  Dictionary with string value
389        (4)  Dictionary with list of strings value
390      In all cases, the returned input data will be a list of strings.
391  
392      Args:
393          data: Input data for the embedding task.
394  
395      Returns:
396          Tuple of input data and parameters dictionary.
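
    Example:
        .. code-block:: python
            # Illustrative payload; "batch_size" stands in for any extra parameter.
            data = {"input": ["hello", "world"], "batch_size": 16}
            input_data, params = preprocess_llm_embedding_params(data)
            assert input_data == ["hello", "world"]
            assert params == {"batch_size": 16}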
397      """
398      if isinstance(data, pd.DataFrame):
399          params = {}
400          for col in data.columns:
401              if col == _LLM_V1_EMBEDDING_INPUT_KEY:
402                  input_data = data[col].to_list()
403                  if isinstance(input_data[0], list):
404                      input_data = input_data[0]
405              else:
406                  params[col] = data[col].tolist()[0]
407      else:
408          # NB: Input schema is not enforced for the embedding task because of the heterogeneous
409          # input type, so we have to cast the input data into unified format here.
410          input_data = data.get(_LLM_V1_EMBEDDING_INPUT_KEY)
411          if isinstance(input, str):
412              input_data = [input_data]
413          params = {k: v for k, v in data.items() if k != _LLM_V1_EMBEDDING_INPUT_KEY}
414  
415      return input_data, params


def postprocess_output_for_llm_v1_embedding_task(
    input_prompts: list[str],
    output_tensors: list[list[float]],
    tokenizer,
):
    """
    Wrap output data with usage information.

    Examples:
        .. code-block:: python
            input_prompts = ["hello world and hello mlflow"]
            output_embeddings = [[0.47137904, 0.4669448, ..., 0.69726706]]
            output_dict = postprocess_output_for_llm_v1_embedding_task(
                input_prompts, output_embeddings, tokenizer
            )
            assert output_dict == {
                "object": "list",
                "data": [
                    {
                        "object": "embedding",
                        "index": 0,
                        "embedding": [0.47137904, 0.4669448, ..., 0.69726706],
                    }
                ],
                "usage": {"prompt_tokens": 8, "total_tokens": 8},
            }

    Args:
        input_prompts: List of text input prompts.
        output_tensors: List of output tensors that contain the generated embeddings.
        tokenizer: The tokenizer object used for inference.

    Returns:
        Dictionary containing the output embeddings and usage information for the
        input prompts.
    """
    prompt_tokens = sum(len(tokenizer(prompt)["input_ids"]) for prompt in input_prompts)
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "index": i,
                "embedding": tensor,
            }
            for i, tensor in enumerate(output_tensors)
        ],
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
    }