__init__.py
import os

from mlflow.deployments import BaseDeploymentClient
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils.openai_utils import (
    _OAITokenHolder,
    _OpenAIApiConfig,
    _OpenAIEnvVar,
)
from mlflow.utils.rest_utils import augmented_raise_for_status

# Values of OPENAI_API_TYPE that select the Azure OpenAI code paths.
_AZURE_API_TYPES = ("azure", "azure_ad", "azuread")

# Base URL used for REST calls when OPENAI_API_BASE is not set.
_DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"

# Upper bound (seconds) on the plain-`requests` calls made by this module, so a
# stalled connection cannot block the caller indefinitely.
_OPENAI_REQUEST_TIMEOUT_SECONDS = 60


class OpenAIDeploymentClient(BaseDeploymentClient):
    """
    Client for interacting with OpenAI endpoints.

    Example:

    First, set up credentials for authentication:

    .. code-block:: bash

        export OPENAI_API_KEY=...

    .. seealso::

        See https://mlflow.org/docs/latest/python_api/openai/index.html for other authentication
        methods.

    Then, create a deployment client and use it to interact with OpenAI endpoints:

    .. code-block:: python

        from mlflow.deployments import get_deploy_client

        client = get_deploy_client("openai")
        client.predict(
            endpoint="gpt-4o-mini",
            inputs={
                "messages": [
                    {"role": "user", "content": "Hello!"},
                ],
            },
        )
    """

    def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def update_deployment(self, name, model_uri=None, flavor=None, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def delete_deployment(self, name, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def list_deployments(self, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def get_deployment(self, name, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def predict(self, deployment_name=None, inputs=None, endpoint=None):
        """Query an OpenAI chat-completions endpoint.
        See https://platform.openai.com/docs/api-reference for more information.

        Args:
            deployment_name: Unused.
            inputs: A dictionary containing the model inputs to query. It must
                contain a ``"messages"`` key, whose value is forwarded as the
                ``messages`` argument of the chat-completions request.
            endpoint: The name of the endpoint to query. It is passed as the
                ``model`` argument of the chat-completions request.

        Returns:
            A dictionary containing the model outputs (the ``model_dump()`` of
            the chat-completions response).

        Raises:
            MlflowException: If ``OPENAI_API_KEY`` is not set, or if ``inputs``
                is missing or has no ``"messages"`` key.
        """
        _check_openai_key()

        # Validate up front so a malformed request fails with a clear MLflow
        # error instead of a bare TypeError/KeyError, and before any token
        # refresh or client construction happens.
        if inputs is None or "messages" not in inputs:
            raise MlflowException(
                "`inputs` must be a dictionary containing a 'messages' key.",
                error_code=INVALID_PARAMETER_VALUE,
            )

        api_config = _get_api_config_without_openai_dep()
        api_token = _OAITokenHolder(api_config.api_type)
        api_token.refresh()

        if api_config.api_type in _AZURE_API_TYPES:
            from openai import AzureOpenAI

            client = AzureOpenAI(
                api_key=api_token.token,
                azure_endpoint=api_config.api_base,
                api_version=api_config.api_version,
                azure_deployment=api_config.deployment_id,
                max_retries=api_config.max_retries,
                timeout=api_config.timeout,
            )
        else:
            from openai import OpenAI

            client = OpenAI(
                api_key=api_token.token,
                base_url=api_config.api_base,
                max_retries=api_config.max_retries,
                timeout=api_config.timeout,
            )

        return client.chat.completions.create(
            messages=inputs["messages"], model=endpoint
        ).model_dump()

    def create_endpoint(self, name, config=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def update_endpoint(self, endpoint, config=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def delete_endpoint(self, endpoint):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def list_endpoints(self):
        """
        List the currently available models.

        Issues ``GET {api_base}/models``, where ``api_base`` is ``OPENAI_API_BASE``
        when set and ``https://api.openai.com/v1`` otherwise.

        Returns:
            The decoded JSON response body.

        Raises:
            MlflowException: If ``OPENAI_API_KEY`` is not set.
            NotImplementedError: If the configured API type is Azure OpenAI.
        """
        _check_openai_key()

        api_config = _get_api_config_without_openai_dep()
        if api_config.api_type in _AZURE_API_TYPES:
            raise NotImplementedError(
                "List endpoints is not implemented for Azure OpenAI API",
            )
        return _get_openai_resource(api_config, "models")

    def get_endpoint(self, endpoint):
        """
        Get information about a specific model.

        Issues ``GET {api_base}/models/{endpoint}``, where ``api_base`` is
        ``OPENAI_API_BASE`` when set and ``https://api.openai.com/v1`` otherwise.

        Args:
            endpoint: The model name to look up.

        Returns:
            The decoded JSON response body.

        Raises:
            MlflowException: If ``OPENAI_API_KEY`` is not set.
            NotImplementedError: If the configured API type is Azure OpenAI.
        """
        _check_openai_key()

        api_config = _get_api_config_without_openai_dep()
        if api_config.api_type in _AZURE_API_TYPES:
            raise NotImplementedError(
                "Get endpoint is not implemented for Azure OpenAI API",
            )
        return _get_openai_resource(api_config, f"models/{endpoint}")


def run_local(name, model_uri, flavor=None, config=None):
    """No-op module-level hook (presumably part of the MLflow deployment-plugin
    interface — confirm); local serving is not supported for this target."""


def target_help():
    """No-op module-level hook (presumably part of the MLflow deployment-plugin
    interface — confirm); returns None."""


def _get_openai_resource(api_config: _OpenAIApiConfig, path: str):
    """Perform an authenticated GET of ``{api_base}/{path}`` and return the JSON body.

    ``api_base`` comes from ``api_config.api_base`` (OPENAI_API_BASE), which is
    the same base URL ``predict`` passes to the OpenAI client. When it is unset,
    the public OpenAI endpoint is used. Errors for non-success responses come
    from ``augmented_raise_for_status``.

    Callers must have already verified that ``OPENAI_API_KEY`` is set.
    """
    import requests

    # Strip a trailing slash so "https://host/v1/" and "https://host/v1" both work.
    api_base = (api_config.api_base or _DEFAULT_OPENAI_API_BASE).rstrip("/")
    response = requests.get(
        f"{api_base}/{path}",
        headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"},
        timeout=_OPENAI_REQUEST_TIMEOUT_SECONDS,
    )
    augmented_raise_for_status(response)
    return response.json()


def _get_api_config_without_openai_dep() -> _OpenAIApiConfig:
    """
    Gets the parameters and configuration of the OpenAI API connected to.

    Reads OPENAI_API_TYPE, OPENAI_API_VERSION, OPENAI_API_BASE and
    OPENAI_DEPLOYMENT_NAME from the environment without importing ``openai``.
    Azure API types get a smaller batch size and token budget than the default
    (non-Azure) configuration.
    """
    api_type = os.environ.get(_OpenAIEnvVar.OPENAI_API_TYPE.value)
    api_version = os.environ.get(_OpenAIEnvVar.OPENAI_API_VERSION.value)
    api_base = os.environ.get(_OpenAIEnvVar.OPENAI_API_BASE.value)
    deployment_id = os.environ.get(_OpenAIEnvVar.OPENAI_DEPLOYMENT_NAME.value)
    if api_type in _AZURE_API_TYPES:
        batch_size = 16
        max_tokens_per_minute = 60_000
    else:
        # The maximum batch size is 2048:
        # https://github.com/openai/openai-python/blob/b82a3f7e4c462a8a10fa445193301a3cefef9a4a/openai/embeddings_utils.py#L43
        # We use a smaller batch size to be safe.
        batch_size = 1024
        max_tokens_per_minute = 90_000
    return _OpenAIApiConfig(
        api_type=api_type,
        batch_size=batch_size,
        max_requests_per_minute=3_500,
        max_tokens_per_minute=max_tokens_per_minute,
        api_base=api_base,
        api_version=api_version,
        deployment_id=deployment_id,
    )


def _check_openai_key():
    """Raise ``MlflowException`` (INVALID_PARAMETER_VALUE) if OPENAI_API_KEY is not set."""
    if "OPENAI_API_KEY" not in os.environ:
        raise MlflowException(
            "OPENAI_API_KEY environment variable not set",
            error_code=INVALID_PARAMETER_VALUE,
        )