# mlflow/deployments/openai/__init__.py
import os
from collections.abc import Mapping

from mlflow.deployments import BaseDeploymentClient
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils.openai_utils import (
    _OAITokenHolder,
    _OpenAIApiConfig,
    _OpenAIEnvVar,
)
from mlflow.utils.rest_utils import augmented_raise_for_status
 12  
 13  
 14  class OpenAIDeploymentClient(BaseDeploymentClient):
 15      """
 16      Client for interacting with OpenAI endpoints.
 17  
 18      Example:
 19  
 20      First, set up credentials for authentication:
 21  
 22      .. code-block:: bash
 23  
 24          export OPENAI_API_KEY=...
 25  
 26      .. seealso::
 27  
 28          See https://mlflow.org/docs/latest/python_api/openai/index.html for other authentication
 29          methods.
 30  
 31      Then, create a deployment client and use it to interact with OpenAI endpoints:
 32  
 33      .. code-block:: python
 34  
 35          from mlflow.deployments import get_deploy_client
 36  
 37          client = get_deploy_client("openai")
 38          client.predict(
 39              endpoint="gpt-4o-mini",
 40              inputs={
 41                  "messages": [
 42                      {"role": "user", "content": "Hello!"},
 43                  ],
 44              },
 45          )
 46      """
 47  
 48      def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
 49          """
 50          .. warning::
 51  
 52              This method is not implemented for `OpenAIDeploymentClient`.
 53          """
 54          raise NotImplementedError
 55  
 56      def update_deployment(self, name, model_uri=None, flavor=None, config=None, endpoint=None):
 57          """
 58          .. warning::
 59  
 60              This method is not implemented for `OpenAIDeploymentClient`.
 61          """
 62          raise NotImplementedError
 63  
 64      def delete_deployment(self, name, config=None, endpoint=None):
 65          """
 66          .. warning::
 67  
 68              This method is not implemented for `OpenAIDeploymentClient`.
 69          """
 70          raise NotImplementedError
 71  
 72      def list_deployments(self, endpoint=None):
 73          """
 74          .. warning::
 75  
 76              This method is not implemented for `OpenAIDeploymentClient`.
 77          """
 78          raise NotImplementedError
 79  
 80      def get_deployment(self, name, endpoint=None):
 81          """
 82          .. warning::
 83  
 84              This method is not implemented for `OpenAIDeploymentClient`.
 85          """
 86          raise NotImplementedError
 87  
 88      def predict(self, deployment_name=None, inputs=None, endpoint=None):
 89          """Query an OpenAI endpoint.
 90          See https://platform.openai.com/docs/api-reference for more information.
 91  
 92          Args:
 93              deployment_name: Unused.
 94              inputs: A dictionary containing the model inputs to query.
 95              endpoint: The name of the endpoint to query.
 96  
 97          Returns:
 98              A dictionary containing the model outputs.
 99  
100          """
101          _check_openai_key()
102  
103          api_config = _get_api_config_without_openai_dep()
104          api_token = _OAITokenHolder(api_config.api_type)
105          api_token.refresh()
106  
107          if api_config.api_type in ("azure", "azure_ad", "azuread"):
108              from openai import AzureOpenAI
109  
110              client = AzureOpenAI(
111                  api_key=api_token.token,
112                  azure_endpoint=api_config.api_base,
113                  api_version=api_config.api_version,
114                  azure_deployment=api_config.deployment_id,
115                  max_retries=api_config.max_retries,
116                  timeout=api_config.timeout,
117              )
118          else:
119              from openai import OpenAI
120  
121              client = OpenAI(
122                  api_key=api_token.token,
123                  base_url=api_config.api_base,
124                  max_retries=api_config.max_retries,
125                  timeout=api_config.timeout,
126              )
127  
128          return client.chat.completions.create(
129              messages=inputs["messages"], model=endpoint
130          ).model_dump()
131  
132      def create_endpoint(self, name, config=None):
133          """
134          .. warning::
135  
136              This method is not implemented for `OpenAIDeploymentClient`.
137          """
138          raise NotImplementedError
139  
140      def update_endpoint(self, endpoint, config=None):
141          """
142          .. warning::
143  
144              This method is not implemented for `OpenAIDeploymentClient`.
145          """
146          raise NotImplementedError
147  
148      def delete_endpoint(self, endpoint):
149          """
150          .. warning::
151  
152              This method is not implemented for `OpenAIDeploymentClient`.
153          """
154          raise NotImplementedError
155  
156      def list_endpoints(self):
157          """
158          List the currently available models.
159          """
160  
161          _check_openai_key()
162  
163          api_config = _get_api_config_without_openai_dep()
164          import requests
165  
166          if api_config.api_type in ("azure", "azure_ad", "azuread"):
167              raise NotImplementedError(
168                  "List endpoints is not implemented for Azure OpenAI API",
169              )
170          else:
171              api_key = os.environ["OPENAI_API_KEY"]
172              request_header = {"Authorization": f"Bearer {api_key}"}
173  
174              response = requests.get(
175                  "https://api.openai.com/v1/models",
176                  headers=request_header,
177              )
178  
179              augmented_raise_for_status(response)
180  
181              return response.json()
182  
183      def get_endpoint(self, endpoint):
184          """
185          Get information about a specific model.
186          """
187  
188          _check_openai_key()
189  
190          api_config = _get_api_config_without_openai_dep()
191          import requests
192  
193          if api_config.api_type in ("azure", "azure_ad", "azuread"):
194              raise NotImplementedError(
195                  "Get endpoint is not implemented for Azure OpenAI API",
196              )
197          else:
198              api_key = os.environ["OPENAI_API_KEY"]
199              request_header = {"Authorization": f"Bearer {api_key}"}
200  
201              response = requests.get(
202                  f"https://api.openai.com/v1/models/{endpoint}",
203                  headers=request_header,
204              )
205  
206              augmented_raise_for_status(response)
207  
208              return response.json()
209  
210  
def run_local(name, model_uri, flavor=None, config=None):
    """
    Deployment-plugin hook for running ``model_uri`` locally.

    .. warning::

        Local deployment is not supported by the OpenAI deployment target. The function raises
        instead of silently doing nothing, so callers are not misled into thinking a deployment
        was started.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("run_local is not implemented for the OpenAI deployment target")
213  
214  
def target_help():
    """
    Return help text for the OpenAI deployment target.

    Returns:
        A human-readable string describing how to authenticate and which operations the
        target supports.
    """
    return (
        "MLflow deployment target for querying OpenAI chat completion models.\n"
        "\n"
        "Authentication: set the OPENAI_API_KEY environment variable. See\n"
        "https://mlflow.org/docs/latest/python_api/openai/index.html for other\n"
        "authentication methods.\n"
        "\n"
        "Supported operations:\n"
        "  predict         Query a model through the chat completions API. `endpoint` is\n"
        "                  the model name and `inputs` must contain a 'messages' key.\n"
        "  list_endpoints  List the available models (not supported for Azure OpenAI).\n"
        "  get_endpoint    Describe a single model (not supported for Azure OpenAI).\n"
        "\n"
        "Creating, updating, deleting or listing deployments, creating, updating or\n"
        "deleting endpoints, and running locally are not supported.\n"
    )
217  
218  
def _get_api_config_without_openai_dep() -> _OpenAIApiConfig:
    """
    Build the OpenAI API configuration from environment variables, without importing ``openai``.

    The API type, version, base URL and deployment name come from the environment variables
    named by ``_OpenAIEnvVar``; any that are unset become ``None``. The batch size and the
    tokens-per-minute limit depend on whether the API type is one of the Azure variants. The
    requests-per-minute limit is always 3,500.
    """

    def _read(env_var):
        # ``os.environ.get`` already returns None for unset variables.
        return os.environ.get(env_var.value)

    api_type = _read(_OpenAIEnvVar.OPENAI_API_TYPE)
    is_azure = api_type in ("azure", "azure_ad", "azuread")
    if is_azure:
        batch_size, max_tokens_per_minute = 16, 60_000
    else:
        # The client library caps batches at 2048:
        # https://github.com/openai/openai-python/blob/b82a3f7e4c462a8a10fa445193301a3cefef9a4a/openai/embeddings_utils.py#L43
        # Half of that is used to leave a safety margin.
        batch_size, max_tokens_per_minute = 1024, 90_000

    return _OpenAIApiConfig(
        api_type=api_type,
        batch_size=batch_size,
        max_requests_per_minute=3_500,
        max_tokens_per_minute=max_tokens_per_minute,
        api_base=_read(_OpenAIEnvVar.OPENAI_API_BASE),
        api_version=_read(_OpenAIEnvVar.OPENAI_API_VERSION),
        deployment_id=_read(_OpenAIEnvVar.OPENAI_DEPLOYMENT_NAME),
    )
245  
246  
def _check_openai_key():
    """
    Verify that the ``OPENAI_API_KEY`` environment variable is present.

    Only presence is checked; the value itself is not inspected.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` when the variable is absent.
    """
    if "OPENAI_API_KEY" in os.environ:
        return
    raise MlflowException(
        "OPENAI_API_KEY environment variable not set",
        error_code=INVALID_PARAMETER_VALUE,
    )