/ mlflow / gemini / chat.py
chat.py
  1  import json
  2  import logging
  3  from typing import TYPE_CHECKING
  4  
  5  from mlflow.types.chat import (
  6      ChatTool,
  7      Function,
  8      FunctionParams,
  9      FunctionToolDefinition,
 10      ParamProperty,
 11      ToolCall,
 12  )
 13  
 14  if TYPE_CHECKING:
 15      from google import genai
 16  
 17  _logger = logging.getLogger(__name__)
 18  
 19  
 20  def convert_gemini_func_to_mlflow_chat_tool(
 21      function_def: "genai.types.FunctionDeclaration",
 22  ) -> ChatTool:
 23      """
 24      Convert Gemini function definition into MLflow's standard format (OpenAI compatible).
 25      Ref: https://ai.google.dev/gemini-api/docs/function-calling
 26  
 27      Args:
 28          function_def: A genai.types.FunctionDeclaration or genai.protos.FunctionDeclaration object
 29                        representing a function definition.
 30  
 31      Returns:
 32          ChatTool: MLflow's standard tool definition object.
 33      """
 34      return ChatTool(
 35          type="function",
 36          function=FunctionToolDefinition(
 37              name=function_def.name,
 38              description=function_def.description,
 39              parameters=_convert_gemini_function_param_to_mlflow_function_param(
 40                  function_def.parameters
 41              ),
 42          ),
 43      )
 44  
 45  
 46  def convert_gemini_func_call_to_mlflow_tool_call(
 47      func_call: "genai.types.FunctionCall",
 48  ) -> ToolCall:
 49      """
 50      Convert Gemini function call into MLflow's standard format (OpenAI compatible).
 51      Ref: https://ai.google.dev/gemini-api/docs/function-calling
 52  
 53      Args:
 54          func_call: A genai.types.FunctionCall or genai.protos.FunctionCall object
 55                     representing a single func call.
 56  
 57      Returns:
 58          ToolCall: MLflow's standard tool call object.
 59      """
 60      # original args object is not json serializable
 61      args = func_call.args or {}
 62  
 63      return ToolCall(
 64          # Gemini does not have func call id
 65          id=func_call.name,
 66          type="function",
 67          function=Function(name=func_call.name, arguments=json.dumps(dict(args))),
 68      )
 69  
 70  
 71  def _convert_gemini_param_property_to_mlflow_param_property(param_property) -> ParamProperty:
 72      """
 73      Convert Gemini parameter property definition into MLflow's standard format (OpenAI compatible).
 74      Ref: https://ai.google.dev/gemini-api/docs/function-calling
 75  
 76      Args:
 77          param_property: A genai.types.Schema or genai.protos.Schema object
 78                          representing a parameter property.
 79  
 80      Returns:
 81          ParamProperty: MLflow's standard param property object.
 82      """
 83      type_name = param_property.type
 84      type_name = type_name.name.lower() if hasattr(type_name, "name") else type_name.lower()
 85      return ParamProperty(
 86          description=param_property.description,
 87          enum=param_property.enum,
 88          type=type_name,
 89      )
 90  
 91  
 92  def _convert_gemini_function_param_to_mlflow_function_param(
 93      function_params: "genai.types.Schema",
 94  ) -> FunctionParams:
 95      """
 96      Convert Gemini function parameter definition into MLflow's standard format (OpenAI compatible).
 97      Ref: https://ai.google.dev/gemini-api/docs/function-calling
 98  
 99      Args:
100          function_params: A genai.types.Schema or genai.protos.Schema object
101                           representing function parameters.
102  
103      Returns:
104          FunctionParams: MLflow's standard function parameter object.
105      """
106      return FunctionParams(
107          properties={
108              k: _convert_gemini_param_property_to_mlflow_param_property(v)
109              for k, v in function_params.properties.items()
110          },
111          required=function_params.required,
112      )