/ mlflow / sagemaker / __init__.py
__init__.py
   1  """
   2  The ``mlflow.sagemaker`` module provides an API for deploying MLflow models to Amazon SageMaker.
   3  """
   4  
   5  import json
   6  import logging
   7  import os
   8  import signal
   9  import subprocess
  10  import sys
  11  import tarfile
  12  import time
  13  import urllib.parse
  14  import uuid
  15  from typing import Any
  16  
  17  import mlflow
  18  import mlflow.version
  19  from mlflow import pyfunc
  20  from mlflow.deployments import BaseDeploymentClient, PredictionsResponse
  21  from mlflow.environment_variables import (
  22      MLFLOW_DEPLOYMENT_FLAVOR_NAME,
  23      MLFLOW_SAGEMAKER_DEPLOY_IMG_URL,
  24  )
  25  from mlflow.exceptions import MlflowException
  26  from mlflow.models import Model
  27  from mlflow.models.container import (
  28      SERVING_ENVIRONMENT,
  29  )
  30  from mlflow.models.container import (
  31      SUPPORTED_FLAVORS as SUPPORTED_DEPLOYMENT_FLAVORS,
  32  )
  33  from mlflow.models.model import MLMODEL_FILE_NAME
  34  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST
  35  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
  36  from mlflow.utils.file_utils import TempDir
  37  from mlflow.utils.proto_json_utils import dump_input_data
  38  
  39  DEFAULT_IMAGE_NAME = "mlflow-pyfunc"
  40  DEPLOYMENT_MODE_ADD = "add"
  41  DEPLOYMENT_MODE_REPLACE = "replace"
  42  DEPLOYMENT_MODE_CREATE = "create"
  43  
  44  DEPLOYMENT_MODES = [DEPLOYMENT_MODE_CREATE, DEPLOYMENT_MODE_ADD, DEPLOYMENT_MODE_REPLACE]
  45  
  46  DEFAULT_BUCKET_NAME_PREFIX = "mlflow-sagemaker"
  47  
  48  DEFAULT_SAGEMAKER_INSTANCE_TYPE = "ml.m4.xlarge"
  49  DEFAULT_SAGEMAKER_INSTANCE_COUNT = 1
  50  
  51  DEFAULT_REGION_NAME = "us-west-2"
  52  SAGEMAKER_SERVING_ENVIRONMENT = "SageMaker"
  53  SAGEMAKER_APP_NAME_TAG_KEY = "app_name"
  54  
  55  _logger = logging.getLogger(__name__)
  56  
  57  _full_template = "{account}.dkr.ecr.{region}.amazonaws.com/{image}:{version}"
  58  
  59  
  60  def _get_preferred_deployment_flavor(model_config):
  61      """
  62      Obtains the flavor that MLflow would prefer to use when deploying the model.
  63      If the model does not contain any supported flavors for deployment, an exception
  64      will be thrown.
  65  
  66      Args:
  67          model_config: An MLflow model object
  68  
  69      Returns:
  70          The name of the preferred deployment flavor for the specified model
  71      """
  72      if pyfunc.FLAVOR_NAME in model_config.flavors:
  73          return pyfunc.FLAVOR_NAME
  74      else:
  75          raise MlflowException(
  76              message=(
  77                  "The specified model does not contain any of the supported flavors for"
  78                  " deployment. The model contains the following flavors: {model_flavors}."
  79                  " Supported flavors: {supported_flavors}".format(
  80                      model_flavors=model_config.flavors.keys(),
  81                      supported_flavors=SUPPORTED_DEPLOYMENT_FLAVORS,
  82                  )
  83              ),
  84              error_code=RESOURCE_DOES_NOT_EXIST,
  85          )
  86  
  87  
  88  def _validate_deployment_flavor(model_config, flavor):
  89      """
  90      Checks that the specified flavor is a supported deployment flavor
  91      and is contained in the specified model. If one of these conditions
  92      is not met, an exception is thrown.
  93  
  94      Args:
  95          model_config: An MLflow Model object
  96          flavor: The deployment flavor to validate
  97      """
  98      if flavor not in SUPPORTED_DEPLOYMENT_FLAVORS:
  99          raise MlflowException(
 100              message=(
 101                  f"The specified flavor: `{flavor}` is not supported for deployment."
 102                  f" Please use one of the supported flavors: {SUPPORTED_DEPLOYMENT_FLAVORS}"
 103              ),
 104              error_code=INVALID_PARAMETER_VALUE,
 105          )
 106      elif flavor not in model_config.flavors:
 107          raise MlflowException(
 108              message=(
 109                  "The specified model does not contain the specified deployment flavor:"
 110                  f" `{flavor}`. Please use one of the following deployment flavors"
 111                  f" that the model contains: {model_config.flavors.keys()}"
 112              ),
 113              error_code=RESOURCE_DOES_NOT_EXIST,
 114          )
 115  
 116  
 117  def push_image_to_ecr(image=DEFAULT_IMAGE_NAME):
 118      """
 119      Push local Docker image to AWS ECR.
 120  
 121      The image is pushed under currently active AWS account and to the currently active AWS region.
 122  
 123      Args:
 124          image: Docker image name.
 125      """
 126      import boto3
 127  
 128      _logger.info("Pushing image to ECR")
 129      client = boto3.client("sts")
 130      caller_id = client.get_caller_identity()
 131      account = caller_id["Account"]
 132      my_session = boto3.session.Session()
 133      region = my_session.region_name or "us-west-2"
 134      fullname = _full_template.format(
 135          account=account, region=region, image=image, version=mlflow.version.VERSION
 136      )
 137      _logger.info("Pushing docker image %s to %s", image, fullname)
 138      ecr_client = boto3.client("ecr")
 139      try:
 140          ecr_client.describe_repositories(repositoryNames=[image])["repositories"]
 141      except ecr_client.exceptions.RepositoryNotFoundException:
 142          ecr_client.create_repository(repositoryName=image)
 143          _logger.info("Created new ECR repository: %s", image)
 144      registry = f"{account}.dkr.ecr.{region}.amazonaws.com"
 145  
 146      try:
 147          # Docker login: get password from AWS CLI and pipe to docker login
 148          _logger.info("Logging in to ECR registry: %s", registry)
 149          aws_result = subprocess.run(
 150              ["aws", "ecr", "get-login-password"],
 151              capture_output=True,
 152              check=True,
 153          )
 154          subprocess.run(
 155              ["docker", "login", "--username", "AWS", "--password-stdin", registry],
 156              input=aws_result.stdout,
 157              check=True,
 158          )
 159  
 160          # Docker tag
 161          _logger.info("Tagging image %s as %s", image, fullname)
 162          subprocess.check_call(["docker", "tag", image, fullname])
 163  
 164          # Docker push
 165          _logger.info("Pushing image %s", fullname)
 166          subprocess.check_call(["docker", "push", fullname])
 167      except subprocess.CalledProcessError as e:
 168          cmd = " ".join(e.cmd)
 169          raise MlflowException(
 170              f"Failed to push image to ECR. Command '{cmd}' failed with exit code {e.returncode}"
 171          ) from e
 172  
 173  
 174  def _deploy(
 175      app_name,
 176      model_uri,
 177      execution_role_arn=None,
 178      assume_role_arn=None,
 179      bucket=None,
 180      image_url=None,
 181      region_name="us-west-2",
 182      mode=DEPLOYMENT_MODE_CREATE,
 183      archive=False,
 184      instance_type=DEFAULT_SAGEMAKER_INSTANCE_TYPE,
 185      instance_count=DEFAULT_SAGEMAKER_INSTANCE_COUNT,
 186      vpc_config=None,
 187      flavor=None,
 188      synchronous=True,
 189      timeout_seconds=1200,
 190      data_capture_config=None,
 191      variant_name=None,
 192      async_inference_config=None,
 193      serverless_config=None,
 194      env=None,
 195      tags=None,
 196  ):
 197      """
 198      Deploy an MLflow model on AWS SageMaker.
 199      The currently active AWS account must have correct permissions set up.
 200  
 201      This function creates a SageMaker endpoint. For more information about the input data
 202      formats accepted by this endpoint, see the
 203      `MLflow deployment tools documentation <../../deployment/deploy-model-to-sagemaker.html>`_.
 204  
 205      Args:
 206          app_name: Name of the deployed application.
 207          model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker.
 208              For example:
 209  
 210              - ``/Users/me/path/to/local/model``
 211              - ``relative/path/to/local/model``
 212              - ``s3://my_bucket/path/to/model``
 213              - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
 214              - ``models:/<model_name>/<model_version>``
 215              - ``models:/<model_name>/<stage>``
 216  
 217              For more information about supported URI schemes, see
 218              `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
 219              artifact-locations>`_.
 220  
 221          execution_role_arn: The name of an IAM role granting the SageMaker service permissions to
 222              access the specified Docker image and S3 bucket containing MLflow
 223              model artifacts. If unspecified, the currently-assumed role will be
 224              used. This execution role is passed to the SageMaker service when
 225              creating a SageMaker model from the specified MLflow model. It is
 226              passed as the ``ExecutionRoleArn`` parameter of the `SageMaker
 227              CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/
 228              dg/API_CreateModel.html>`_. This role is *not* assumed for any other
 229              call. For more information about SageMaker execution roles for model
 230              creation, see
 231              https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html.
 232          assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker
 233              to another AWS account. If unspecified, SageMaker will be deployed to
 234              the the currently active AWS account.
 235          bucket: S3 bucket where model artifacts will be stored. Defaults to a
 236              SageMaker-compatible bucket name.
 237          image_url: URL of the ECR-hosted Docker image the model should be deployed into, produced
 238              by ``mlflow sagemaker build-and-push-container``. This parameter can also
 239              be specified by the environment variable ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``.
 240          region_name: Name of the AWS region to which to deploy the application.
 241          mode: The mode in which to deploy the application. Must be one of the following:
 242  
 243              ``mlflow.sagemaker.DEPLOYMENT_MODE_CREATE``
 244                  Create an application with the specified name and model. This fails if an
 245                  application of the same name already exists.
 246  
 247              ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE``
 248                  If an application of the specified name exists, its model(s) is replaced with
 249                  the specified model. If no such application exists, it is created with the
 250                  specified name and model.
 251  
 252              ``mlflow.sagemaker.DEPLOYMENT_MODE_ADD``
 253                  Add the specified model to a pre-existing application with the specified name,
 254                  if one exists. If the application does not exist, a new application is created
 255                  with the specified name and model. NOTE: If the application **already exists**,
 256                  the specified model is added to the application's corresponding SageMaker
 257                  endpoint with an initial weight of zero (0). To route traffic to the model,
 258                  update the application's associated endpoint configuration using either the
 259                  AWS console or the ``UpdateEndpointWeightsAndCapacities`` function defined in
 260                  https://docs.aws.amazon.com/sagemaker/latest/dg/API_UpdateEndpointWeightsAndCapacities.html.
 261  
 262          archive: If ``True``, any pre-existing SageMaker application resources that become
 263              inactive (i.e. as a result of deploying in
 264              ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` mode) are preserved.
 265              These resources may include unused SageMaker models and endpoint configurations
 266              that were associated with a prior version of the application endpoint. If
 267              ``False``, these resources are deleted. In order to use ``archive=False``,
 268              ``deploy()`` must be executed synchronously with ``synchronous=True``.
 269          instance_type: The type of SageMaker ML instance on which to deploy the model. For a list
 270              of supported instance types, see
 271              https://aws.amazon.com/sagemaker/pricing/instance-types/.
 272          instance_count: The number of SageMaker ML instances on which to deploy the model.
 273          vpc_config: A dictionary specifying the VPC configuration to use when creating the
 274              new SageMaker model associated with this application. The acceptable values
 275              for this parameter are identical to those of the ``VpcConfig`` parameter in
 276              the `SageMaker boto3 client's create_model method
 277              <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html
 278              #SageMaker.Client.create_model>`_. For more information, see
 279              https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html.
 280  
 281              .. code-block:: python
 282                  :caption: Example
 283  
 284                      import mlflow.sagemaker as mfs
 285  
 286                      vpc_config = {
 287                          "SecurityGroupIds": [
 288                              "sg-123456abc",
 289                          ],
 290                          "Subnets": [
 291                              "subnet-123456abc",
 292                          ],
 293                      }
 294                      mfs._deploy(..., vpc_config=vpc_config)
 295  
 296          flavor: The name of the flavor of the model to use for deployment. Must be either
 297              ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. If ``None``,
 298              a flavor is automatically selected from the model's available flavors. If the
 299              specified flavor is not present or not supported for deployment, an exception
 300              will be thrown.
 301          synchronous: If ``True``, this function will block until the deployment process succeeds
 302              or encounters an irrecoverable failure. If ``False``, this function will
 303              return immediately after starting the deployment process. It will not wait
 304              for the deployment process to complete; in this case, the caller is
 305              responsible for monitoring the health and status of the pending deployment
 306              via native SageMaker APIs or the AWS console.
 307          timeout_seconds: If ``synchronous`` is ``True``, the deployment process will return after
 308              the specified number of seconds if no definitive result (success or
 309              failure) is achieved. Once the function returns, the caller is
 310              responsible for monitoring the health and status of the pending
 311              deployment using native SageMaker APIs or the AWS console. If
 312              ``synchronous`` is ``False``, this parameter is ignored.
 313          data_capture_config: A dictionary specifying the data capture configuration to use when
 314              creating the new SageMaker model associated with this application.
 315              For more information, see
 316              https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html.
 317  
 318              .. code-block:: python
 319                  :caption: Example
 320  
 321                  import mlflow.sagemaker as mfs
 322  
 323                  data_capture_config = {
 324                      "EnableCapture": True,
 325                      "InitialSamplingPercentage": 100,
 326                      "DestinationS3Uri": "s3://my-bucket/path",
 327                      "CaptureOptions": [{"CaptureMode": "Output"}],
 328                  }
 329                  mfs._deploy(..., data_capture_config=data_capture_config)
 330  
 331          variant_name: The name to assign to the new production variant.
 332          async_inference_config: The name to assign to the endpoint_config
 333              on the sagemaker endpoint.
 334              .. code-block:: python
 335                  :caption: Example
 336  
 337                  {
 338                      "AsyncInferenceConfig": {
 339                          "ClientConfig": {"MaxConcurrentInvocationsPerInstance": 4},
 340                          "OutputConfig": {
 341                              "S3OutputPath": "s3://<path-to-output-bucket>",
 342                              "NotificationConfig": {},
 343                          },
 344                      }
 345                  }
 346  
 347          serverless_config: An optional dictionary specifying the serverless configuration
 348              .. code-block:: python
 349                  :caption: Example
 350  
 351                  {
 352                      "ServerlessConfig": {
 353                          "MemorySizeInMB": 2048,
 354                          "MaxConcurrency": 20,
 355                      }
 356                  }
 357  
 358          env: An optional dictionary of environment variables to set for the model.
 359          tags: An optional dictionary of tags to apply to the endpoint.
 360      """
 361      import boto3
 362  
 363      if (not archive) and (not synchronous):
 364          raise MlflowException(
 365              message=(
 366                  "Resources must be archived when `deploy()` is executed in non-synchronous mode."
 367                  " Either set `synchronous=True` or `archive=True`."
 368              ),
 369              error_code=INVALID_PARAMETER_VALUE,
 370          )
 371  
 372      if mode not in DEPLOYMENT_MODES:
 373          raise MlflowException(
 374              message="`mode` must be one of: {deployment_modes}".format(
 375                  deployment_modes=",".join(DEPLOYMENT_MODES)
 376              ),
 377              error_code=INVALID_PARAMETER_VALUE,
 378          )
 379  
 380      model_path = _download_artifact_from_uri(model_uri)
 381      model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME)
 382      if not os.path.exists(model_config_path):
 383          raise MlflowException(
 384              message=(
 385                  f"Failed to find {MLMODEL_FILE_NAME} configuration within the specified model's "
 386                  "root directory."
 387              ),
 388              error_code=INVALID_PARAMETER_VALUE,
 389          )
 390      model_config = Model.load(model_config_path)
 391  
 392      if flavor is None:
 393          flavor = _get_preferred_deployment_flavor(model_config)
 394      else:
 395          _validate_deployment_flavor(model_config, flavor)
 396      _logger.info("Using the %s flavor for deployment!", flavor)
 397  
 398      assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn)
 399  
 400      s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials)
 401      sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials)
 402  
 403      endpoint_exists = _find_endpoint(endpoint_name=app_name, sage_client=sage_client) is not None
 404      if endpoint_exists and mode == DEPLOYMENT_MODE_CREATE:
 405          raise MlflowException(
 406              message=(
 407                  f"You are attempting to deploy an application with name: {app_name} in"
 408                  f" '{DEPLOYMENT_MODE_CREATE}' mode. However, an application with the same name"
 409                  " already exists. If you want to update this application, deploy in"
 410                  f" '{DEPLOYMENT_MODE_ADD}' or '{DEPLOYMENT_MODE_REPLACE}' mode."
 411              ),
 412              error_code=INVALID_PARAMETER_VALUE,
 413          )
 414  
 415      model_name = _get_sagemaker_model_name(endpoint_name=app_name)
 416  
 417      if not image_url:
 418          image_url = _get_default_image_url(region_name=region_name)
 419      if not execution_role_arn:
 420          execution_role_arn = _get_assumed_role_arn(**assume_role_credentials)
 421      if not bucket:
 422          _logger.info("No model data bucket specified, using the default bucket")
 423          bucket = _get_default_s3_bucket(region_name, **assume_role_credentials)
 424  
 425      model_s3_path = _upload_s3(
 426          local_model_path=model_path,
 427          bucket=bucket,
 428          prefix=model_name,
 429          region_name=region_name,
 430          s3_client=s3_client,
 431          **assume_role_credentials,
 432      )
 433  
 434      if endpoint_exists:
 435          deployment_operation = _update_sagemaker_endpoint(
 436              endpoint_name=app_name,
 437              model_name=model_name,
 438              model_s3_path=model_s3_path,
 439              model_uri=model_uri,
 440              image_url=image_url,
 441              flavor=flavor,
 442              instance_type=instance_type,
 443              instance_count=instance_count,
 444              vpc_config=vpc_config,
 445              mode=mode,
 446              role=execution_role_arn,
 447              sage_client=sage_client,
 448              s3_client=s3_client,
 449              variant_name=variant_name,
 450              async_inference_config=async_inference_config,
 451              serverless_config=serverless_config,
 452              data_capture_config=data_capture_config,
 453              env=env,
 454              tags=tags,
 455          )
 456      else:
 457          deployment_operation = _create_sagemaker_endpoint(
 458              endpoint_name=app_name,
 459              model_name=model_name,
 460              model_s3_path=model_s3_path,
 461              model_uri=model_uri,
 462              image_url=image_url,
 463              flavor=flavor,
 464              instance_type=instance_type,
 465              instance_count=instance_count,
 466              vpc_config=vpc_config,
 467              data_capture_config=data_capture_config,
 468              role=execution_role_arn,
 469              sage_client=sage_client,
 470              variant_name=variant_name,
 471              async_inference_config=async_inference_config,
 472              serverless_config=serverless_config,
 473              env=env,
 474              tags=tags,
 475          )
 476  
 477      if synchronous:
 478          _logger.info("Waiting for the deployment operation to complete...")
 479          operation_status = deployment_operation.await_completion(timeout_seconds=timeout_seconds)
 480          if operation_status.state == _SageMakerOperationStatus.STATE_SUCCEEDED:
 481              _logger.info(
 482                  'The deployment operation completed successfully with message: "%s"',
 483                  operation_status.message,
 484              )
 485          else:
 486              raise MlflowException(
 487                  "The deployment operation failed with the following error message:"
 488                  f' "{operation_status.message}"'
 489              )
 490          if not archive:
 491              deployment_operation.clean_up()
 492  
 493      return app_name, flavor
 494  
 495  
 496  def _delete(
 497      app_name,
 498      region_name="us-west-2",
 499      assume_role_arn=None,
 500      archive=False,
 501      synchronous=True,
 502      timeout_seconds=300,
 503  ):
 504      """
 505      Delete a SageMaker application.
 506  
 507      Args:
 508          app_name: Name of the deployed application.
 509          region_name: Name of the AWS region in which the application is deployed.
 510          assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker
 511              to another AWS account. If unspecified, SageMaker will be deployed to
 512              the the currently active AWS account.
 513          archive: If ``True``, resources associated with the specified application, such
 514              as its associated models and endpoint configuration, are preserved.
 515              If ``False``, these resources are deleted. In order to use
 516              ``archive=False``, ``delete()`` must be executed synchronously with
 517              ``synchronous=True``.
 518          synchronous: If `True`, this function blocks until the deletion process succeeds
 519              or encounters an irrecoverable failure. If `False`, this function
 520              returns immediately after starting the deletion process. It will not wait
 521              for the deletion process to complete; in this case, the caller is
 522              responsible for monitoring the status of the deletion process via native
 523              SageMaker APIs or the AWS console.
 524          timeout_seconds: If `synchronous` is `True`, the deletion process returns after the
 525              specified number of seconds if no definitive result (success or failure)
 526              is achieved. Once the function returns, the caller is responsible
 527              for monitoring the status of the deletion process via native SageMaker
 528              APIs or the AWS console. If `synchronous` is False, this parameter
 529              is ignored.
 530      """
 531      import boto3
 532  
 533      if (not archive) and (not synchronous):
 534          raise MlflowException(
 535              message=(
 536                  "Resources must be archived when `delete()` is executed in non-synchronous mode."
 537                  " Either set `synchronous=True` or `archive=True`."
 538              ),
 539              error_code=INVALID_PARAMETER_VALUE,
 540          )
 541  
 542      assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn)
 543  
 544      s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials)
 545      sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials)
 546  
 547      endpoint_info = sage_client.describe_endpoint(EndpointName=app_name)
 548      endpoint_arn = endpoint_info["EndpointArn"]
 549  
 550      sage_client.delete_endpoint(EndpointName=app_name)
 551      _logger.info("Deleted endpoint with arn: %s", endpoint_arn)
 552  
 553      def status_check_fn():
 554          endpoint_info = _find_endpoint(endpoint_name=app_name, sage_client=sage_client)
 555          if endpoint_info is not None:
 556              return _SageMakerOperationStatus.in_progress(
 557                  "Deletion is still in progress. Current endpoint status: {endpoint_status}".format(
 558                      endpoint_status=endpoint_info["EndpointStatus"]
 559                  )
 560              )
 561          else:
 562              return _SageMakerOperationStatus.succeeded(
 563                  "The SageMaker endpoint was deleted successfully."
 564              )
 565  
 566      def cleanup_fn():
 567          _logger.info("Cleaning up unused resources...")
 568          config_name = endpoint_info["EndpointConfigName"]
 569          config_info = sage_client.describe_endpoint_config(EndpointConfigName=config_name)
 570          config_arn = config_info["EndpointConfigArn"]
 571          sage_client.delete_endpoint_config(EndpointConfigName=config_name)
 572          _logger.info("Deleted associated endpoint configuration with arn: %s", config_arn)
 573          for pv in config_info["ProductionVariants"]:
 574              model_name = pv["ModelName"]
 575              model_arn = _delete_sagemaker_model(model_name, sage_client, s3_client)
 576              _logger.info("Deleted associated model with arn: %s", model_arn)
 577  
 578      delete_operation = _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn)
 579  
 580      if synchronous:
 581          _logger.info("Waiting for the delete operation to complete...")
 582          operation_status = delete_operation.await_completion(timeout_seconds=timeout_seconds)
 583          if operation_status.state == _SageMakerOperationStatus.STATE_SUCCEEDED:
 584              _logger.info(
 585                  'The deletion operation completed successfully with message: "%s"',
 586                  operation_status.message,
 587              )
 588          else:
 589              raise MlflowException(
 590                  "The deletion operation failed with the following error message:"
 591                  f' "{operation_status.message}"'
 592              )
 593          if not archive:
 594              delete_operation.clean_up()
 595  
 596  
 597  def deploy_transform_job(
 598      job_name,
 599      model_uri,
 600      s3_input_data_type,
 601      s3_input_uri,
 602      content_type,
 603      s3_output_path,
 604      compression_type="None",
 605      split_type="Line",
 606      accept="text/csv",
 607      assemble_with="Line",
 608      input_filter="$",
 609      output_filter="$",
 610      join_resource="None",
 611      execution_role_arn=None,
 612      assume_role_arn=None,
 613      bucket=None,
 614      image_url=None,
 615      region_name="us-west-2",
 616      instance_type=DEFAULT_SAGEMAKER_INSTANCE_TYPE,
 617      instance_count=DEFAULT_SAGEMAKER_INSTANCE_COUNT,
 618      vpc_config=None,
 619      flavor=None,
 620      archive=False,
 621      synchronous=True,
 622      timeout_seconds=1200,
 623  ):
 624      """
 625      Deploy an MLflow model on AWS SageMaker and create the corresponding batch transform job.
 626      The currently active AWS account must have correct permissions set up.
 627  
 628      Args:
 629          job_name: Name of the deployed Sagemaker batch transform job.
 630          model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker.
 631              For example:
 632  
 633              - ``/Users/me/path/to/local/model``
 634              - ``relative/path/to/local/model``
 635              - ``s3://my_bucket/path/to/model``
 636              - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
 637              - ``models:/<model_name>/<model_version>``
 638              - ``models:/<model_name>/<stage>``
 639  
 640              For more information about supported URI schemes, see
 641              `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
 642              artifact-locations>`_.
 643  
 644          s3_input_data_type: Input data type for the transform job.
 645          s3_input_uri: S3 key name prefix or a manifest of the input data.
 646          content_type: The multipurpose internet mail extension (MIME) type of the data.
 647          s3_output_path: The S3 path to store the output results of the Sagemaker transform job.
 648          compression_type: The compression type of the transform data.
 649          split_type: The method to split the transform job's data files into smaller batches.
 650          accept: The multipurpose internet mail extension (MIME) type of the output data.
 651          assemble_with: The method to assemble the results of the transform job as
 652              a single S3 object.
 653          input_filter: A JSONPath expression used to select a portion of the input data for
 654              the transform job.
 655          output_filter: A JSONPath expression used to select a portion of the output data from
 656              the transform job.
 657          join_resource: The source of the data to join with the transformed data.
 658  
 659          execution_role_arn: The name of an IAM role granting the SageMaker service permissions to
 660              access the specified Docker image and S3 bucket containing MLflow
 661              model artifacts. If unspecified, the currently-assumed role will be
 662              used. This execution role is passed to the SageMaker service when
 663              creating a SageMaker model from the specified MLflow model. It is
 664              passed as the ``ExecutionRoleArn`` parameter of the `SageMaker
 665              CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/
 666              dg/API_CreateModel.html>`_. This role is *not* assumed for any other
 667              call. For more information about SageMaker execution roles for model
 668              creation, see
 669              https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html.
 670          assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker
 671              to another AWS account. If unspecified, SageMaker will be deployed to
                the currently active AWS account.
 673          bucket: S3 bucket where model artifacts will be stored. Defaults to a
 674              SageMaker-compatible bucket name.
 675          image_url: URL of the ECR-hosted Docker image the model should be deployed into, produced
 676              by ``mlflow sagemaker build-and-push-container``. This parameter can also
 677              be specified by the environment variable ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``.
 678          region_name: Name of the AWS region to which to deploy the application.
 679          instance_type: The type of SageMaker ML instance on which to deploy the model. For a list
 680              of supported instance types, see
 681              https://aws.amazon.com/sagemaker/pricing/instance-types/.
 682          instance_count: The number of SageMaker ML instances on which to deploy the model.
 683          vpc_config: A dictionary specifying the VPC configuration to use when creating the
 684              new SageMaker model associated with this batch transform job. The acceptable
 685              values for this parameter are identical to those of the ``VpcConfig``
 686              parameter in the `SageMaker boto3 client's create_model method
 687              <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html
 688              #SageMaker.Client.create_model>`_. For more information, see
 689              https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html.
 690  
 691              .. code-block:: python
 692                  :caption: Example
 693  
 694                  import mlflow.sagemaker as mfs
 695  
 696                  vpc_config = {
 697                      "SecurityGroupIds": [
 698                          "sg-123456abc",
 699                      ],
 700                      "Subnets": [
 701                          "subnet-123456abc",
 702                      ],
 703                  }
 704                  mfs.deploy_transform_job(..., vpc_config=vpc_config)
 705  
 706          flavor: The name of the flavor of the model to use for deployment. Must be either
 707              ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. If ``None``,
 708              a flavor is automatically selected from the model's available flavors. If the
 709              specified flavor is not present or not supported for deployment, an exception
 710              will be thrown.
 711          archive: If ``True``, resources like Sagemaker models and model artifacts in S3 are
 712              preserved after the finished batch transform job. If ``False``, these resources
 713              are deleted. In order to use ``archive=False``, ``deploy_transform_job()`` must
 714              be executed synchronously with ``synchronous=True``.
 715          synchronous: If ``True``, this function will block until the deployment process succeeds
 716              or encounters an irrecoverable failure. If ``False``, this function will
 717              return immediately after starting the deployment process. It will not wait
 718              for the deployment process to complete; in this case, the caller is
 719              responsible for monitoring the health and status of the pending deployment
 720              via native SageMaker APIs or the AWS console.
 721          timeout_seconds: If ``synchronous`` is ``True``, the deployment process will return after
 722              the specified number of seconds if no definitive result (success or
 723              failure) is achieved. Once the function returns, the caller is
 724              responsible for monitoring the health and status of the pending
 725              deployment using native SageMaker APIs or the AWS console. If
 726              ``synchronous`` is ``False``, this parameter is ignored.
 727      """
 728      import boto3
 729  
 730      if (not archive) and (not synchronous):
 731          raise MlflowException(
 732              message=(
 733                  "Resources must be archived when `deploy_transform_job()`"
 734                  " is executed in non-synchronous mode."
 735                  " Either set `synchronous=True` or `archive=True`."
 736              ),
 737              error_code=INVALID_PARAMETER_VALUE,
 738          )
 739  
 740      model_path = _download_artifact_from_uri(model_uri)
 741      model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME)
 742      if not os.path.exists(model_config_path):
 743          raise MlflowException(
 744              message=(
 745                  f"Failed to find {MLMODEL_FILE_NAME} configuration within the specified model's"
 746                  " root directory."
 747              ),
 748              error_code=INVALID_PARAMETER_VALUE,
 749          )
 750      model_config = Model.load(model_config_path)
 751  
 752      if flavor is None:
 753          flavor = _get_preferred_deployment_flavor(model_config)
 754      else:
 755          _validate_deployment_flavor(model_config, flavor)
 756      _logger.info("Using the %s flavor for deployment!", flavor)
 757  
 758      assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn)
 759  
 760      s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials)
 761      sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials)
 762  
 763      transform_job_exists = (
 764          _find_transform_job(job_name=job_name, sage_client=sage_client) is not None
 765      )
 766      if transform_job_exists:
 767          raise MlflowException(
 768              message=(
 769                  f"You are attempting to deploy a batch transform job with name: {job_name}. "
 770                  "However, a batch transform job with the same name already exists."
 771              ),
 772              error_code=INVALID_PARAMETER_VALUE,
 773          )
 774  
 775      model_name = _get_sagemaker_transform_model_name(job_name=job_name)
 776      if not image_url:
 777          image_url = _get_default_image_url(region_name=region_name)
 778      if not execution_role_arn:
 779          execution_role_arn = _get_assumed_role_arn(**assume_role_credentials)
 780      if not bucket:
 781          _logger.info("No model data bucket specified, using the default bucket")
 782          bucket = _get_default_s3_bucket(region_name, **assume_role_credentials)
 783  
 784      model_s3_path = _upload_s3(
 785          local_model_path=model_path,
 786          bucket=bucket,
 787          prefix=model_name,
 788          region_name=region_name,
 789          s3_client=s3_client,
 790          **assume_role_credentials,
 791      )
 792  
 793      deployment_operation = _create_sagemaker_transform_job(
 794          job_name=job_name,
 795          model_name=model_name,
 796          model_s3_path=model_s3_path,
 797          model_uri=model_uri,
 798          image_url=image_url,
 799          flavor=flavor,
 800          vpc_config=vpc_config,
 801          role=execution_role_arn,
 802          sage_client=sage_client,
 803          s3_client=s3_client,
 804          instance_type=instance_type,
 805          instance_count=instance_count,
 806          s3_input_data_type=s3_input_data_type,
 807          s3_input_uri=s3_input_uri,
 808          content_type=content_type,
 809          compression_type=compression_type,
 810          split_type=split_type,
 811          s3_output_path=s3_output_path,
 812          accept=accept,
 813          assemble_with=assemble_with,
 814          input_filter=input_filter,
 815          output_filter=output_filter,
 816          join_resource=join_resource,
 817      )
 818  
 819      if synchronous:
 820          _logger.info("Waiting for the batch transform job to complete...")
 821          operation_status = deployment_operation.await_completion(timeout_seconds=timeout_seconds)
 822          if operation_status.state == _SageMakerOperationStatus.STATE_SUCCEEDED:
 823              _logger.info(
 824                  'The batch transform job completed successfully with message: "%s"',
 825                  operation_status.message,
 826              )
 827          else:
 828              raise MlflowException(
 829                  "The batch transform job failed with the following error message:"
 830                  f' "{operation_status.message}"'
 831              )
 832          if not archive:
 833              deployment_operation.clean_up()
 834  
 835  
 836  def terminate_transform_job(
 837      job_name,
 838      region_name="us-west-2",
 839      assume_role_arn=None,
 840      archive=False,
 841      synchronous=True,
 842      timeout_seconds=300,
 843  ):
 844      """
 845      Terminate a SageMaker batch transform job.
 846  
 847      Args:
 848          job_name: Name of the deployed Sagemaker batch transform job.
 849          region_name: Name of the AWS region in which the batch transform job is deployed.
 850          assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker
 851              to another AWS account. If unspecified, SageMaker will be deployed to
 852              the the currently active AWS account.
 853          archive: If ``True``, resources associated with the specified batch transform job,
 854              such as its associated models and model artifacts, are preserved.
 855              If ``False``, these resources are deleted. In order to use ``archive=False``,
 856              ``terminate_transform_job()`` must be executed synchronously
 857              with ``synchronous=True``.
 858          synchronous: If `True`, this function blocks until the termination process succeeds
 859              or encounters an irrecoverable failure. If `False`, this function
 860              returns immediately after starting the termination process. It will not
 861              wait for the termination process to complete; in this case, the caller is
 862              responsible for monitoring the status of the termination process via native
 863              SageMaker APIs or the AWS console.
 864          timeout_seconds: If `synchronous` is `True`, the termination process returns after the
 865              specified number of seconds if no definitive result (success or failure)
 866              is achieved. Once the function returns, the caller is responsible
 867              for monitoring the status of the termination process via native
 868              SageMaker APIs or the AWS console. If `synchronous` is False, this
 869              parameter is ignored.
 870      """
 871      import boto3
 872  
 873      if (not archive) and (not synchronous):
 874          raise MlflowException(
 875              message=(
 876                  "Resources must be archived when `terminate_transform_job()`"
 877                  " is executed in non-synchronous mode."
 878                  " Either set `synchronous=True` or `archive=True`."
 879              ),
 880              error_code=INVALID_PARAMETER_VALUE,
 881          )
 882  
 883      assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn)
 884  
 885      s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials)
 886      sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials)
 887  
 888      transform_job_info = sage_client.describe_transform_job(TransformJobName=job_name)
 889      transform_job_arn = transform_job_info["TransformJobArn"]
 890  
 891      sage_client.stop_transform_job(TransformJobName=job_name)
 892      _logger.info("Terminated batch transform job with arn: %s", transform_job_arn)
 893  
 894      def status_check_fn():
 895          transform_job_info = _find_transform_job(job_name=job_name, sage_client=sage_client)
 896  
 897          if transform_job_info["TransformJobStatus"] == "Stopping":
 898              return _SageMakerOperationStatus.in_progress(
 899                  "Termination is still in progress. Current batch transform job status: "
 900                  "{transform_job_status}".format(
 901                      transform_job_status=transform_job_info["TransformJobStatus"]
 902                  )
 903              )
 904          elif transform_job_info["TransformJobStatus"] == "Stopped":
 905              return _SageMakerOperationStatus.succeeded(
 906                  "The SageMaker batch transform job was terminated successfully."
 907              )
 908  
 909      def cleanup_fn():
 910          _logger.info("Cleaning up unused resources...")
 911          model_name = transform_job_info["ModelName"]
 912          model_arn = _delete_sagemaker_model(model_name, sage_client, s3_client)
 913          _logger.info("Deleted associated model with arn: %s", model_arn)
 914  
 915      stop_operation = _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn)
 916  
 917      if synchronous:
 918          _logger.info("Waiting for the termination operation to complete...")
 919          operation_status = stop_operation.await_completion(timeout_seconds=timeout_seconds)
 920          if operation_status.state == _SageMakerOperationStatus.STATE_SUCCEEDED:
 921              _logger.info(
 922                  'The termination operation completed successfully with message: "%s"',
 923                  operation_status.message,
 924              )
 925          else:
 926              raise MlflowException(
 927                  "The termination operation failed with the following error message:"
 928                  f' "{operation_status.message}"'
 929              )
 930          if not archive:
 931              stop_operation.clean_up()
 932  
 933  
 934  def push_model_to_sagemaker(
 935      model_name,
 936      model_uri,
 937      execution_role_arn=None,
 938      assume_role_arn=None,
 939      bucket=None,
 940      image_url=None,
 941      region_name="us-west-2",
 942      vpc_config=None,
 943      flavor=None,
 944  ):
 945      """
 946      Create a SageMaker Model from an MLflow model artifact.
 947      The currently active AWS account must have correct permissions set up.
 948  
 949      Args:
 950          model_name: Name of the Sagemaker model.
 951          model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker.
 952              For example:
 953  
 954              - ``/Users/me/path/to/local/model``
 955              - ``relative/path/to/local/model``
 956              - ``s3://my_bucket/path/to/model``
 957              - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
 958              - ``models:/<model_name>/<model_version>``
 959              - ``models:/<model_name>/<stage>``
 960  
 961              For more information about supported URI schemes, see
 962              `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
 963              artifact-locations>`_.
 964  
 965          execution_role_arn: The name of an IAM role granting the SageMaker service permissions to
 966              access the specified Docker image and S3 bucket containing MLflow
 967              model artifacts. If unspecified, the currently-assumed role will be
 968              used. This execution role is passed to the SageMaker service when
 969              creating a SageMaker model from the specified MLflow model. It is
 970              passed as the ``ExecutionRoleArn`` parameter of the `SageMaker
 971              CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/
 972              dg/API_CreateModel.html>`_. This role is *not* assumed for any other
 973              call. For more information about SageMaker execution roles for model
 974              creation, see
 975              https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html.
 976          assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker
 977              to another AWS account. If unspecified, SageMaker will be deployed to
 978              the the currently active AWS account.
 979          bucket: S3 bucket where model artifacts will be stored. Defaults to a
 980              SageMaker-compatible bucket name.
 981          image_url: URL of the ECR-hosted Docker image the model should be deployed into, produced
 982              by ``mlflow sagemaker build-and-push-container``. This parameter can also
 983              be specified by the environment variable ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``.
 984          region_name: Name of the AWS region to which to deploy the application.
 985          vpc_config: A dictionary specifying the VPC configuration to use when creating the
 986              new SageMaker model. The acceptable values for this parameter are identical
 987              to those of the ``VpcConfig`` parameter in the `SageMaker boto3 client's
 988              create_model method
 989              <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html
 990              #SageMaker.Client.create_model>`_. For more information, see
 991              https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html.
 992  
 993              .. code-block:: python
 994                  :caption: Example
 995  
 996                  import mlflow.sagemaker as mfs
 997  
 998                  vpc_config = {
 999                      "SecurityGroupIds": [
1000                          "sg-123456abc",
1001                      ],
1002                      "Subnets": [
1003                          "subnet-123456abc",
1004                      ],
1005                  }
1006                  mfs.push_model_to_sagemaker(..., vpc_config=vpc_config)
1007  
1008          flavor: The name of the flavor of the model to use for deployment. Must be either
1009              ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. If ``None``,
1010              a flavor is automatically selected from the model's available flavors. If the
1011              specified flavor is not present or not supported for deployment, an exception
1012              will be thrown.
1013      """
1014      import boto3
1015  
1016      model_path = _download_artifact_from_uri(model_uri)
1017      model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME)
1018      if not os.path.exists(model_config_path):
1019          raise MlflowException(
1020              message=(
1021                  f"Failed to find {MLMODEL_FILE_NAME} configuration within the specified model's"
1022                  " root directory."
1023              ),
1024              error_code=INVALID_PARAMETER_VALUE,
1025          )
1026      model_config = Model.load(model_config_path)
1027  
1028      if flavor is None:
1029          flavor = _get_preferred_deployment_flavor(model_config)
1030      else:
1031          _validate_deployment_flavor(model_config, flavor)
1032      _logger.info("Using the %s flavor for deployment!", flavor)
1033  
1034      assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn)
1035  
1036      s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials)
1037      sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials)
1038  
1039      if _does_model_exist(model_name=model_name, sage_client=sage_client):
1040          raise MlflowException(
1041              message=(
1042                  f"You are attempting to create a Sagemaker model with name: {model_name}. "
1043                  "However, a model with the same name already exists."
1044              ),
1045              error_code=INVALID_PARAMETER_VALUE,
1046          )
1047  
1048      if not image_url:
1049          image_url = _get_default_image_url(region_name=region_name)
1050      if not execution_role_arn:
1051          execution_role_arn = _get_assumed_role_arn(**assume_role_credentials)
1052      if not bucket:
1053          _logger.info("No model data bucket specified, using the default bucket")
1054          bucket = _get_default_s3_bucket(region_name, **assume_role_credentials)
1055  
1056      model_s3_path = _upload_s3(
1057          local_model_path=model_path,
1058          bucket=bucket,
1059          prefix=model_name,
1060          region_name=region_name,
1061          s3_client=s3_client,
1062          **assume_role_credentials,
1063      )
1064  
1065      model_response = _create_sagemaker_model(
1066          model_name=model_name,
1067          model_s3_path=model_s3_path,
1068          model_uri=model_uri,
1069          flavor=flavor,
1070          vpc_config=vpc_config,
1071          image_url=image_url,
1072          execution_role=execution_role_arn,
1073          sage_client=sage_client,
1074          env={},
1075          tags={},
1076      )
1077  
1078      _logger.info("Created Sagemaker model with arn: %s", model_response["ModelArn"])
1079  
1080  
def run_local(name, model_uri, flavor=None, config=None):
    """
    Serve the model locally in a SageMaker compatible Docker container.

    Note that models deployed locally cannot be managed by other deployment APIs
    (e.g. ``update_deployment``, ``delete_deployment``, etc).

    Args:
        name: Name of the local serving application.
        model_uri: The location, in URI format, of the MLflow model to deploy locally.
            For example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.
        flavor: The name of the flavor of the model to use for deployment. Must be either
            ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS.
            If ``None``, a flavor is automatically selected from the model's available
            flavors. If the specified flavor is not present or not supported for
            deployment, an exception will be thrown.
        config: Optional configuration parameters. The supported parameters are:

            - ``image``: The name of the Docker image to use for model serving. Defaults
              to ``"mlflow-pyfunc"``.
            - ``port``: The port at which to expose the model server on the local host.
              Defaults to ``5000``.

    .. code-block:: python
        :caption: Python example

        from mlflow.models import build_docker
        from mlflow.deployments import get_deploy_client

        build_docker(name="mlflow-pyfunc")

        client = get_deploy_client("sagemaker")
        client.run_local(
            name="my-local-deployment",
            model_uri="/mlruns/0/abc/model",
            flavor="python_function",
            config={
                "port": 5000,
                "image": "mlflow-pyfunc",
            },
        )

    .. code-block:: bash
        :caption:  Command-line example

        mlflow models build-docker --name "mlflow-pyfunc"
        mlflow deployments run-local --target sagemaker \\
                --name my-local-deployment \\
                --model-uri "/mlruns/0/abc/model" \\
                --flavor python_function \\
                -C port=5000 \\
                -C image="mlflow-pyfunc"
    """
    model_path = _download_artifact_from_uri(model_uri)
    model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME)
    model_config = Model.load(model_config_path)

    if flavor is None:
        flavor = _get_preferred_deployment_flavor(model_config)
    else:
        _validate_deployment_flavor(model_config, flavor)
    _logger.info("Using the %s flavor for local serving!", flavor)

    # `config` defaults to None; guard before `.get` to avoid an AttributeError
    # when the caller omits it.
    config = config or {}
    image = config.get("image", DEFAULT_IMAGE_NAME)
    port = int(config.get("port", 5000))

    deployment_config = _get_deployment_config(flavor_name=flavor)

    _logger.info("launching docker image with path %s", model_path)
    # Mount the local model directory into the container and expose the serving port.
    cmd = ["docker", "run", "-v", f"{model_path}:/opt/ml/model/", "-p", f"{port}:8080"]
    for key, value in deployment_config.items():
        cmd += ["-e", f"{key}={value}"]
    cmd += ["--rm", image, "serve"]
    _logger.info("executing: %s", " ".join(cmd))
    proc = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stderr, text=True)

    def _sigterm_handler(*_):
        # Forward SIGTERM to the docker process as SIGINT so it shuts down cleanly.
        _logger.info("received termination signal => killing docker process")
        proc.send_signal(signal.SIGINT)

    signal.signal(signal.SIGTERM, _sigterm_handler)
    proc.wait()
1174  
1175  
def target_help():
    """Return the usage help text shown for the SageMaker deployment target."""
    help_text = """\
    For detailed documentation on the SageMaker deployment client, please visit
    https://mlflow.org/docs/latest/python_api/mlflow.sagemaker.html#mlflow.sagemaker.SageMakerDeploymentClient

    The target URI must follow the following formats:
    - sagemaker
    - sagemaker:/region_name
    - sagemaker:/region_name/assume_role_arn

    When the region_name or assume_role_arn are provided, they will be used as the default region
    and assumed role ARN when executing the commands.

    The `create` and `update` commands require a deployment name and a model_uri. The model flavor
    and deployment configuration can be optionally provided. These commands can also be executed
    in synchronous or asynchronous mode.

    The `delete` command accepts configurations to archive a model instead of deleting, execute
    in asynchronous mode and timeout period.
    """
    return help_text
1199  
1200  
def _get_default_image_url(region_name):
    """
    Resolve the Docker image URL to deploy.

    The ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL`` environment variable takes precedence;
    otherwise the ``mlflow-pyfunc`` repository in the caller's ECR registry is used,
    tagged with the current MLflow version.

    Args:
        region_name: AWS region whose ECR registry is queried.

    Returns:
        A fully qualified Docker image URL string.
    """
    import boto3

    if env_img := MLFLOW_SAGEMAKER_DEPLOY_IMG_URL.get():
        return env_img

    ecr_client = boto3.client("ecr", region_name=region_name)
    repository_conf = ecr_client.describe_repositories(repositoryNames=[DEFAULT_IMAGE_NAME])[
        "repositories"
    ][0]
    # Interpolate with an f-string rather than calling str.format on a string that
    # embeds the repository URI: any brace characters in the URI would otherwise be
    # misinterpreted as format fields.
    return f"{repository_conf['repositoryUri']}:{mlflow.version.VERSION}"
1212  
1213  
def _get_account_id(**assume_role_credentials):
    """Return the AWS account ID reported by STS for the (possibly assumed) caller."""
    import boto3

    session = boto3.Session()
    caller_identity = session.client("sts", **assume_role_credentials).get_caller_identity()
    return caller_identity["Account"]
1221  
1222  
def _get_assumed_role_arn(**assume_role_credentials):
    """
    Returns:
        ARN of the user's current IAM role.
    """
    import boto3

    session = boto3.Session()
    caller_arn = session.client("sts", **assume_role_credentials).get_caller_identity()["Arn"]
    # The role name is the second slash-delimited component of the STS caller ARN.
    role_name = caller_arn.split("/")[1]
    iam_client = session.client("iam", **assume_role_credentials)
    return iam_client.get_role(RoleName=role_name)["Role"]["Arn"]
1238  
1239  
1240  def _assume_role_and_get_credentials(assume_role_arn=None):
1241      """
1242      Assume a new role in AWS and return the credentials for that role.
1243      When ``assume_role_arn`` is ``None`` or an empty string,
1244      this function does nothing and returns an empty dictionary.
1245  
1246      Args:
1247          assume_role_arn: Optional ARN of the role that will be assumed
1248  
1249      Returns:
1250          Dict with credentials of the assumed role
1251      """
1252      import boto3
1253  
1254      if not assume_role_arn:
1255          return {}
1256  
1257      sts_client = boto3.client("sts")
1258      sts_response = sts_client.assume_role(
1259          RoleArn=assume_role_arn, RoleSessionName="mlflow-sagemaker"
1260      )
1261  
1262      _logger.info("Assuming role %s for deployment!", assume_role_arn)
1263  
1264      return {
1265          "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"],
1266          "aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"],
1267          "aws_session_token": sts_response["Credentials"]["SessionToken"],
1268      }
1269  
1270  
def _get_default_s3_bucket(region_name, **assume_role_credentials):
    """
    Return the name of the default MLflow SageMaker bucket for this account/region,
    creating the bucket first if it does not already exist.
    """
    import boto3

    session = boto3.Session()
    account_id = _get_account_id(**assume_role_credentials)
    bucket_name = f"{DEFAULT_BUCKET_NAME_PREFIX}-{region_name}-{account_id}"
    s3 = session.client("s3", **assume_role_credentials)
    existing_buckets = [entry["Name"] for entry in s3.list_buckets()["Buckets"]]

    if bucket_name in existing_buckets:
        _logger.info("Default bucket `%s` already exists. Skipping creation.", bucket_name)
        return bucket_name

    _logger.info("Default bucket `%s` not found. Creating...", bucket_name)
    creation_kwargs = {
        "ACL": "bucket-owner-full-control",
        "Bucket": bucket_name,
    }
    if region_name != "us-east-1":
        # The location constraint is required during bucket creation for all regions
        # outside of us-east-1. This constraint cannot be specified in us-east-1;
        # specifying it in this region results in a failure, so we will only
        # add it if we are deploying outside of us-east-1.
        # See https://docs.aws.amazon.com/cli/latest/reference/s3api/create-bucket.html#examples
        creation_kwargs["CreateBucketConfiguration"] = {"LocationConstraint": region_name}
    response = s3.create_bucket(**creation_kwargs)
    _logger.info("Bucket creation response: %s", response)
    return bucket_name
1301  
1302  
1303  def _make_tarfile(output_filename, source_dir):
1304      """
1305      create a tar.gz from a directory.
1306      """
1307      with tarfile.open(output_filename, "w:gz") as tar:
1308          for f in os.listdir(source_dir):
1309              tar.add(os.path.join(source_dir, f), arcname=f)
1310  
1311  
def _upload_s3(local_model_path, bucket, prefix, region_name, s3_client, **assume_role_credentials):
    """
    Upload dir to S3 as .tar.gz.

    Args:
        local_model_path: Local path to a dir.
        bucket: S3 bucket where to store the data.
        prefix: Path within the bucket.
        region_name: The AWS region in which to upload data to S3.
        s3_client: A boto3 client for S3, used to tag the uploaded object.

    Returns:
        S3 path of the uploaded artifact.
    """
    import posixpath

    import boto3

    sess = boto3.Session(region_name=region_name, **assume_role_credentials)
    with TempDir() as tmp:
        model_data_file = tmp.path("model.tar.gz")
        _make_tarfile(model_data_file, local_model_path)
        with open(model_data_file, "rb") as fobj:
            # S3 object keys always use forward slashes; `os.path.join` would
            # produce a backslash-separated (invalid) key on Windows, so build
            # the key with `posixpath` instead.
            key = posixpath.join(prefix, "model.tar.gz")
            # The upload goes through the (possibly role-assumed) session, while
            # tagging uses the caller-provided `s3_client`.
            obj = sess.resource("s3").Bucket(bucket).Object(key)
            obj.upload_fileobj(fobj)
            response = s3_client.put_object_tagging(
                Bucket=bucket, Key=key, Tagging={"TagSet": [{"Key": "SageMaker", "Value": "true"}]}
            )
            _logger.info("tag response: %s", response)
            return f"s3://{bucket}/{key}"
1341  
1342  
def _get_deployment_config(flavor_name, env_override=None):
    """
    Build the environment-variable mapping baked into the serving container.

    Returns:
        The deployment configuration as a dictionary
    """
    deployment_config = {
        MLFLOW_DEPLOYMENT_FLAVOR_NAME.name: flavor_name,
        SERVING_ENVIRONMENT: SAGEMAKER_SERVING_ENVIRONMENT,
    }
    if env_override:
        deployment_config.update(env_override)

    # Forward any proxy settings present in the current process environment so
    # the container inherits them.
    for proxy_var in ("http_proxy", "https_proxy", "no_proxy"):
        proxy_value = os.environ.get(proxy_var)
        if proxy_value is not None:
            deployment_config[proxy_var] = proxy_value

    return deployment_config
1365  
1366  
1367  def _truncate_name(name, max_length):
1368      # NB: Sagemaker prevents the registration of models and configurations whose names
1369      # exceed 63 characters in length. For reference:
1370      # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_Model.html
1371      # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_TransformJob.html
1372      # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ModelConfiguration.html
1373      # This function middle-truncates the name provided to
1374      # ensure that the least critical name information is not lost
1375      if len(name) <= max_length:
1376          return name
1377      available_length = max_length - 3
1378      start_len = available_length // 2
1379      end_len = available_length - start_len
1380      truncated_name = f"{name[:start_len]}---{name[-end_len:]}"
1381      _logger.warning(
1382          f"Truncated name {name} to {truncated_name} to coerce total character counts to < 64"
1383      )
1384      return truncated_name
1385  
1386  
def _get_unique_name(base_name, unique_suffix, unique_id_length=20):
    """Append a random suffix to ``base_name``, truncating to stay within 63 chars."""
    random_fragment = uuid.uuid4().hex[:unique_id_length]
    suffix = f"{unique_suffix}{random_fragment}"
    return _truncate_name(base_name, 63 - len(suffix)) + suffix
1392  
1393  
def _get_sagemaker_model_name(endpoint_name):
    """Generate a unique SageMaker model name derived from ``endpoint_name``."""
    suffix = "-model-"
    return _get_unique_name(endpoint_name, suffix)
1396  
1397  
def _get_sagemaker_transform_model_name(job_name):
    """Generate a unique SageMaker model name derived from the transform ``job_name``."""
    suffix = "-model-"
    return _get_unique_name(job_name, suffix)
1400  
1401  
def _get_sagemaker_config_name(endpoint_name):
    """Generate a unique SageMaker endpoint-configuration name for ``endpoint_name``."""
    suffix = "-config-"
    return _get_unique_name(endpoint_name, suffix)
1404  
1405  
def _get_sagemaker_config_tags(endpoint_name):
    """Return the default tag set (the app-name tag) for SageMaker resources."""
    app_name_tag = {"Key": SAGEMAKER_APP_NAME_TAG_KEY, "Value": endpoint_name}
    return [app_name_tag]
1408  
1409  
def _prepare_sagemaker_tags(
    config_tags: list[dict[str, str]],
    sagemaker_tags: dict[str, str] | None = None,
):
    """Merge user-supplied tags into the default tag list in AWS Key/Value form.

    Raises an ``MlflowException`` if the user attempts to override the reserved
    app-name tag.
    """
    if not sagemaker_tags:
        return config_tags

    if SAGEMAKER_APP_NAME_TAG_KEY in sagemaker_tags:
        raise MlflowException.invalid_parameter_value(
            f"Duplicate tag provided for '{SAGEMAKER_APP_NAME_TAG_KEY}'"
        )

    extra_tags = []
    for tag_key, tag_value in sagemaker_tags.items():
        extra_tags.append({"Key": tag_key, "Value": str(tag_value)})

    return config_tags + extra_tags
1424  
1425  
def _create_sagemaker_transform_job(
    job_name,
    model_name,
    model_s3_path,
    model_uri,
    image_url,
    flavor,
    vpc_config,
    role,
    sage_client,
    s3_client,
    instance_type,
    instance_count,
    s3_input_data_type,
    s3_input_uri,
    content_type,
    compression_type,
    split_type,
    s3_output_path,
    accept,
    assemble_with,
    input_filter,
    output_filter,
    join_resource,
):
    """
    Register a SageMaker model for the given artifacts and start a batch transform job
    that uses it.

    Args:
        job_name: Name of the deployed Sagemaker batch transform job.
        model_name: The name to assign the new SageMaker model that will be associated with the
            specified batch transform job.
        model_s3_path: S3 path where we stored the model artifacts.
        model_uri: URI of the MLflow model to associate with the specified SageMaker batch
            transform job.
        image_url: URL of the ECR-hosted docker image the model is being deployed into.
        flavor: The name of the flavor of the model to use for deployment.
        vpc_config: A dictionary specifying the VPC configuration to use when creating the
            new SageMaker model associated with this SageMaker batch transform job.
        role: SageMaker execution ARN role.
        sage_client: A boto3 client for SageMaker.
        s3_client: A boto3 client for S3.
        instance_type: The type of SageMaker ML instance on which to deploy the model.
        instance_count: The number of SageMaker ML instances on which to deploy the model.
        s3_input_data_type: Input data type for the transform job.
        s3_input_uri: S3 key name prefix or a manifest of the input data.
        content_type: The multipurpose internet mail extension (MIME) type of the data.
        compression_type: The compression type of the transform data.
        split_type: The method to split the transform job's data files into smaller batches.
        s3_output_path: The S3 path to store the output results of the Sagemaker transform job.
        accept: The multipurpose internet mail extension (MIME) type of the output data.
        assemble_with: The method to assemble the results of the transform job as a single
            S3 object.
        input_filter: A JSONPath expression used to select a portion of the input data for the
            transform job.
        output_filter: A JSONPath expression used to select a portion of the output data from
            the transform job.
        join_resource: The source of the data to join with the transformed data.

    Returns:
        A ``_SageMakerOperation`` whose status callback polls the transform job until it
        completes or fails, and whose cleanup callback deletes the associated model and
        its S3 artifacts.
    """
    _logger.info("Creating new batch transform job with name: %s ...", job_name)

    # The model must exist before the transform job can reference it by name.
    # No extra env vars or tags are attached for batch transform models.
    model_response = _create_sagemaker_model(
        model_name=model_name,
        model_s3_path=model_s3_path,
        model_uri=model_uri,
        flavor=flavor,
        vpc_config=vpc_config,
        image_url=image_url,
        execution_role=role,
        sage_client=sage_client,
        env={},
        tags={},
    )
    _logger.info("Created model with arn: %s", model_response["ModelArn"])

    # The request sections below mirror the CreateTransformJob API structure:
    # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html
    transform_input = {
        "DataSource": {"S3DataSource": {"S3DataType": s3_input_data_type, "S3Uri": s3_input_uri}},
        "ContentType": content_type,
        "CompressionType": compression_type,
        "SplitType": split_type,
    }

    transform_output = {
        "S3OutputPath": s3_output_path,
        "Accept": accept,
        "AssembleWith": assemble_with,
    }

    transform_resources = {"InstanceType": instance_type, "InstanceCount": instance_count}

    data_processing = {
        "InputFilter": input_filter,
        "OutputFilter": output_filter,
        "JoinSource": join_resource,
    }

    transform_job_response = sage_client.create_transform_job(
        TransformJobName=job_name,
        ModelName=model_name,
        TransformInput=transform_input,
        TransformOutput=transform_output,
        TransformResources=transform_resources,
        DataProcessing=data_processing,
        Tags=[{"Key": "model_name", "Value": model_name}],
    )
    _logger.info(
        "Created batch transform job with arn: %s", transform_job_response["TransformJobArn"]
    )

    def status_check_fn():
        # Poll the job; it transitions InProgress -> Completed, or to a terminal
        # failure state carrying a FailureReason.
        transform_job_info = sage_client.describe_transform_job(TransformJobName=job_name)

        if transform_job_info is None:
            return _SageMakerOperationStatus.in_progress(
                "Waiting for batch transform job to be created..."
            )

        transform_job_status = transform_job_info["TransformJobStatus"]
        if transform_job_status == "InProgress":
            return _SageMakerOperationStatus.in_progress(
                'Waiting for batch transform job to reach the "Completed" state.                   '
                f'  Current batch transform job status: "{transform_job_status}"'
            )
        elif transform_job_status == "Completed":
            return _SageMakerOperationStatus.succeeded(
                "The SageMaker batch transform job was processed successfully."
            )
        else:
            failure_reason = transform_job_info.get(
                "FailureReason",
                "An unknown SageMaker failure occurred. Please see the SageMaker console logs"
                " for more information.",
            )
            return _SageMakerOperationStatus.failed(failure_reason)

    def cleanup_fn():
        # Re-resolve the model name from the job rather than capturing the outer
        # `model_name`, then delete both the model and its S3 artifacts.
        _logger.info("Cleaning up Sagemaker model and S3 model artifacts...")
        transform_job_info = sage_client.describe_transform_job(TransformJobName=job_name)
        model_name = transform_job_info["ModelName"]
        model_arn = _delete_sagemaker_model(model_name, sage_client, s3_client)
        _logger.info("Deleted associated model with arn: %s", model_arn)

    return _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn)
1567  
1568  
def _create_sagemaker_endpoint(
    endpoint_name,
    model_name,
    model_s3_path,
    model_uri,
    image_url,
    flavor,
    instance_type,
    vpc_config,
    data_capture_config,
    instance_count,
    role,
    sage_client,
    variant_name=None,
    async_inference_config=None,
    serverless_config=None,
    env=None,
    tags=None,
):
    """
    Create a SageMaker model, endpoint configuration, and endpoint for real-time serving.

    Args:
        endpoint_name: The name of the SageMaker endpoint to create.
        model_name: The name to assign the new SageMaker model that will be associated with the
            specified endpoint.
        model_s3_path: S3 path where we stored the model artifacts.
        model_uri: URI of the MLflow model to associate with the specified SageMaker endpoint.
        image_url: URL of the ECR-hosted docker image the model is being deployed into.
        flavor: The name of the flavor of the model to use for deployment.
        instance_type: The type of SageMaker ML instance on which to deploy the model.
        instance_count: The number of SageMaker ML instances on which to deploy the model.
        vpc_config: A dictionary specifying the VPC configuration to use when creating the
            new SageMaker model associated with this SageMaker endpoint.
        data_capture_config: A dictionary specifying the data capture configuration to use when
            creating the new SageMaker model associated with this application.
        role: SageMaker execution ARN role.
        sage_client: A boto3 client for SageMaker.
        variant_name: The name to assign to the new production variant. Defaults to the
            model name when unset.
        async_inference_config: If provided, passed as ``AsyncInferenceConfig`` on the new
            endpoint configuration.
        serverless_config: If provided, the variant is deployed serverlessly and the
            instance type/count arguments are ignored.
        env: A dictionary of environment variables to set for the model.
        tags: A dictionary of tags to apply to the endpoint.

    Returns:
        A ``_SageMakerOperation`` whose status callback polls the endpoint until it reaches
        the "InService" state or fails. Its cleanup callback is a no-op.
    """
    _logger.info("Creating new endpoint with name: %s ...", endpoint_name)

    model_response = _create_sagemaker_model(
        model_name=model_name,
        model_s3_path=model_s3_path,
        model_uri=model_uri,
        flavor=flavor,
        vpc_config=vpc_config,
        image_url=image_url,
        execution_role=role,
        sage_client=sage_client,
        env=env or {},
        tags=tags or {},
    )
    _logger.info("Created model with arn: %s", model_response["ModelArn"])

    if not variant_name:
        variant_name = model_name

    # A single variant receives all traffic (weight 1) on a freshly created endpoint.
    production_variant = {
        "VariantName": variant_name,
        "ModelName": model_name,
        "InitialVariantWeight": 1,
    }

    # Serverless and instance-backed deployment are mutually exclusive in the
    # ProductionVariant schema.
    if serverless_config:
        production_variant["ServerlessConfig"] = serverless_config
    else:
        production_variant["InstanceType"] = instance_type
        production_variant["InitialInstanceCount"] = instance_count

    config_name = _get_sagemaker_config_name(endpoint_name)
    config_tags = _get_sagemaker_config_tags(endpoint_name)
    tags_list = _prepare_sagemaker_tags(config_tags, tags)
    # NOTE(review): the endpoint *configuration* is tagged with only the default
    # app-name tag (`config_tags`), while the endpoint itself receives the merged
    # user-supplied tags (`tags_list`) — confirm this asymmetry is intentional.
    endpoint_config_kwargs = {
        "EndpointConfigName": config_name,
        "ProductionVariants": [production_variant],
        "Tags": config_tags,
    }
    if async_inference_config:
        endpoint_config_kwargs["AsyncInferenceConfig"] = async_inference_config
    if data_capture_config is not None:
        endpoint_config_kwargs["DataCaptureConfig"] = data_capture_config
    endpoint_config_response = sage_client.create_endpoint_config(**endpoint_config_kwargs)
    _logger.info(
        "Created endpoint configuration with arn: %s", endpoint_config_response["EndpointConfigArn"]
    )

    endpoint_response = sage_client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=config_name,
        Tags=tags_list or [],
    )
    _logger.info("Created endpoint with arn: %s", endpoint_response["EndpointArn"])

    def status_check_fn():
        # Endpoint creation is asynchronous; poll until it leaves "Creating".
        endpoint_info = _find_endpoint(endpoint_name=endpoint_name, sage_client=sage_client)

        if endpoint_info is None:
            return _SageMakerOperationStatus.in_progress("Waiting for endpoint to be created...")

        endpoint_status = endpoint_info["EndpointStatus"]
        if endpoint_status == "Creating":
            return _SageMakerOperationStatus.in_progress(
                'Waiting for endpoint to reach the "InService" state. Current endpoint status:'
                f' "{endpoint_status}"'
            )
        elif endpoint_status == "InService":
            return _SageMakerOperationStatus.succeeded(
                "The SageMaker endpoint was created successfully."
            )
        else:
            failure_reason = endpoint_info.get(
                "FailureReason",
                "An unknown SageMaker failure occurred. Please see the SageMaker console logs"
                " for more information.",
            )
            return _SageMakerOperationStatus.failed(failure_reason)

    def cleanup_fn():
        # Nothing to clean up for a fresh endpoint creation.
        pass

    return _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn)
1692  
1693  
def _update_sagemaker_endpoint(
    endpoint_name,
    model_name,
    model_uri,
    image_url,
    model_s3_path,
    flavor,
    instance_type,
    instance_count,
    vpc_config,
    mode,
    role,
    sage_client,
    s3_client,
    variant_name=None,
    async_inference_config=None,
    serverless_config=None,
    data_capture_config=None,
    env=None,
    tags=None,
):
    """
    Update an existing SageMaker endpoint by creating a new model and endpoint
    configuration and pointing the endpoint at the new configuration.

    Args:
        endpoint_name: The name of the SageMaker endpoint to update.
        model_name: The name to assign the new SageMaker model that will be associated with the
            specified endpoint.
        model_uri: URI of the MLflow model to associate with the specified SageMaker endpoint.
        image_url: URL of the ECR-hosted Docker image the model is being deployed into
        model_s3_path: S3 path where we stored the model artifacts
        flavor: The name of the flavor of the model to use for deployment.
        instance_type: The type of SageMaker ML instance on which to deploy the model.
        instance_count: The number of SageMaker ML instances on which to deploy the model.
        vpc_config: A dictionary specifying the VPC configuration to use when creating the
            new SageMaker model associated with this SageMaker endpoint.
        mode: either mlflow.sagemaker.DEPLOYMENT_MODE_ADD or
            mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE.
        role: SageMaker execution ARN role.
        sage_client: A boto3 client for SageMaker.
        s3_client: A boto3 client for S3.
        variant_name: The name to assign to the new production variant if it doesn't already exist.
        async_inference_config: A dictionary specifying the async inference configuration to use.
            For more information, see https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AsyncInferenceConfig.html.
            Defaults to ``None``.
        serverless_config: If provided, the new variant is deployed serverlessly and the
            instance type/count arguments are ignored. Defaults to ``None``.
        data_capture_config: A dictionary specifying the data capture configuration to use.
            For more information, see https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html.
            Defaults to ``None``.
        env: A dictionary of environment variables to set for the model.
        tags: A dictionary of tags to apply to the endpoint configuration.

    Returns:
        A ``_SageMakerOperation`` whose status callback polls the endpoint update (detecting
        rollbacks) and whose cleanup callback removes the superseded configuration and, in
        REPLACE mode, the superseded models and artifacts.
    """
    if mode not in [DEPLOYMENT_MODE_ADD, DEPLOYMENT_MODE_REPLACE]:
        msg = f"Invalid mode `{mode}` for deployment to a pre-existing application"
        raise ValueError(msg)

    # Snapshot the currently-deployed configuration; cleanup_fn below needs it to
    # delete superseded resources after the update succeeds.
    endpoint_info = sage_client.describe_endpoint(EndpointName=endpoint_name)
    endpoint_arn = endpoint_info["EndpointArn"]
    deployed_config_name = endpoint_info["EndpointConfigName"]
    deployed_config_info = sage_client.describe_endpoint_config(
        EndpointConfigName=deployed_config_name
    )
    deployed_config_arn = deployed_config_info["EndpointConfigArn"]
    deployed_production_variants = deployed_config_info["ProductionVariants"]

    _logger.info("Found active endpoint with arn: %s. Updating...", endpoint_arn)

    new_model_response = _create_sagemaker_model(
        model_name=model_name,
        model_s3_path=model_s3_path,
        model_uri=model_uri,
        flavor=flavor,
        vpc_config=vpc_config,
        image_url=image_url,
        execution_role=role,
        sage_client=sage_client,
        env=env or {},
        tags=tags or {},
    )
    _logger.info("Created new model with arn: %s", new_model_response["ModelArn"])

    if not variant_name:
        variant_name = model_name

    # ADD: keep the existing variants and add the new one with zero traffic weight.
    # REPLACE: discard the existing variants; the new one takes all traffic.
    if mode == DEPLOYMENT_MODE_ADD:
        new_model_weight = 0
        production_variants = deployed_production_variants
    elif mode == DEPLOYMENT_MODE_REPLACE:
        new_model_weight = 1
        production_variants = []

    new_production_variant = {
        "VariantName": variant_name,
        "ModelName": model_name,
        "InitialVariantWeight": new_model_weight,
    }

    # Serverless and instance-backed deployment are mutually exclusive in the
    # ProductionVariant schema.
    if serverless_config:
        new_production_variant["ServerlessConfig"] = serverless_config
    else:
        new_production_variant["InstanceType"] = instance_type
        new_production_variant["InitialInstanceCount"] = instance_count

    production_variants.append(new_production_variant)

    # Create the new endpoint configuration and update the endpoint
    # to adopt the new configuration
    new_config_name = _get_sagemaker_config_name(endpoint_name)
    config_tags = _get_sagemaker_config_tags(endpoint_name)
    endpoint_config_kwargs = {
        "EndpointConfigName": new_config_name,
        "ProductionVariants": production_variants,
        "Tags": config_tags,
    }
    if async_inference_config:
        endpoint_config_kwargs["AsyncInferenceConfig"] = async_inference_config
    if data_capture_config is not None:
        endpoint_config_kwargs["DataCaptureConfig"] = data_capture_config
    endpoint_config_response = sage_client.create_endpoint_config(**endpoint_config_kwargs)
    _logger.info(
        "Created new endpoint configuration with arn: %s",
        endpoint_config_response["EndpointConfigArn"],
    )

    sage_client.update_endpoint(EndpointName=endpoint_name, EndpointConfigName=new_config_name)
    _logger.info("Updated endpoint with new configuration!")

    operation_start_time = time.time()

    def status_check_fn():
        if time.time() - operation_start_time < 20:
            # Wait at least 20 seconds before checking the status of the update; this ensures
            # that we don't consider the operation to have failed if small delays occur at
            # initialization time
            return _SageMakerOperationStatus.in_progress()

        endpoint_info = sage_client.describe_endpoint(EndpointName=endpoint_name)
        # SageMaker rolls a failed update back to the previous configuration and
        # reports "InService"; detect that by comparing the active config name
        # against the one we just created.
        endpoint_update_was_rolled_back = (
            endpoint_info["EndpointStatus"] == "InService"
            and endpoint_info["EndpointConfigName"] != new_config_name
        )
        if endpoint_update_was_rolled_back or endpoint_info["EndpointStatus"] == "Failed":
            failure_reason = endpoint_info.get(
                "FailureReason",
                "An unknown SageMaker failure occurred."
                " Please see the SageMaker console logs for"
                " more information.",
            )
            return _SageMakerOperationStatus.failed(failure_reason)
        elif endpoint_info["EndpointStatus"] == "InService":
            return _SageMakerOperationStatus.succeeded(
                "The SageMaker endpoint was updated successfully."
            )
        else:
            return _SageMakerOperationStatus.in_progress(
                "The update operation is still in progress. Current endpoint status:"
                ' "{endpoint_status}"'.format(endpoint_status=endpoint_info["EndpointStatus"])
            )

    def cleanup_fn():
        # In REPLACE mode the old variants no longer serve traffic, so their
        # models and artifacts can be deleted; in ADD mode they are still live.
        # The superseded endpoint configuration is removed in both modes.
        _logger.info("Cleaning up unused resources...")
        if mode == DEPLOYMENT_MODE_REPLACE:
            for pv in deployed_production_variants:
                deployed_model_arn = _delete_sagemaker_model(
                    model_name=pv["ModelName"], sage_client=sage_client, s3_client=s3_client
                )
                _logger.info("Deleted model with arn: %s", deployed_model_arn)

        sage_client.delete_endpoint_config(EndpointConfigName=deployed_config_name)
        _logger.info("Deleted endpoint configuration with arn: %s", deployed_config_arn)

    return _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn)
1863  
1864  
def _create_sagemaker_model(
    model_name,
    model_s3_path,
    model_uri,
    flavor,
    vpc_config,
    image_url,
    execution_role,
    sage_client,
    env,
    tags,
):
    """
    Args:
        model_name: The name to assign the new SageMaker model that is created.
        model_s3_path: S3 path where the model artifacts are stored.
        model_uri: URI of the MLflow model associated with the new SageMaker model.
        flavor: The name of the flavor of the model.
        vpc_config: A dictionary specifying the VPC configuration to use when creating the
            new SageMaker model associated with this SageMaker endpoint.
        image_url: URL of the ECR-hosted Docker image that will serve as the
            model's container,
        execution_role: The ARN of the role that SageMaker will assume when creating the model.
        sage_client: A boto3 client for SageMaker.
        env: A dictionary of environment variables to set for the model.
        tags: A dictionary of tags to apply to the SageMaker model. Not mutated.

    Returns:
        AWS response containing metadata associated with the new model.
    """
    # Copy rather than mutate: this function previously wrote "model_uri" into the
    # caller's `tags` dict in place, leaking the tag into any other resource the
    # caller built from the same dict (e.g. endpoint tags).
    model_tags = {**tags, "model_uri": str(model_uri)}
    create_model_args = {
        "ModelName": model_name,
        "PrimaryContainer": {
            "Image": image_url,
            "ModelDataUrl": model_s3_path,
            "Environment": _get_deployment_config(flavor_name=flavor, env_override=env),
        },
        "ExecutionRoleArn": execution_role,
        "Tags": [{"Key": key, "Value": str(value)} for key, value in model_tags.items()],
    }
    # VpcConfig is optional in the CreateModel API; omit it entirely when unset.
    if vpc_config is not None:
        create_model_args["VpcConfig"] = vpc_config

    return sage_client.create_model(**create_model_args)
1910  
1911  
1912  def _delete_sagemaker_model(model_name, sage_client, s3_client):
1913      """
1914      Args:
1915          sage_client: A boto3 client for SageMaker.
1916          s3_client: A boto3 client for S3.
1917  
1918      Returns:
1919          ARN of the deleted model.
1920      """
1921      model_info = sage_client.describe_model(ModelName=model_name)
1922      model_arn = model_info["ModelArn"]
1923      model_data_url = model_info["PrimaryContainer"]["ModelDataUrl"]
1924  
1925      # Parse the model data url to obtain a bucket path. The following
1926      # procedure is safe due to the well-documented structure of the `ModelDataUrl`
1927      # (see https://docs.aws.amazon.com/sagemaker/latest/dg/API_ContainerDefinition.html)
1928      parsed_data_url = urllib.parse.urlparse(model_data_url)
1929      bucket_name = parsed_data_url.netloc
1930      bucket_key = parsed_data_url.path.lstrip("/")
1931  
1932      s3_client.delete_object(Bucket=bucket_name, Key=bucket_key)
1933      sage_client.delete_model(ModelName=model_name)
1934  
1935      return model_arn
1936  
1937  
1938  def _delete_sagemaker_endpoint_configuration(endpoint_config_name, sage_client):
1939      """
1940      Args:
1941          sage_client: A boto3 client for SageMaker.
1942  
1943      Returns:
1944          ARN of the deleted endpoint configuration.
1945      """
1946      endpoint_config_info = sage_client.describe_endpoint_config(
1947          EndpointConfigName=endpoint_config_name
1948      )
1949      sage_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
1950      return endpoint_config_info["EndpointConfigArn"]
1951  
1952  
1953  def _find_endpoint(endpoint_name, sage_client):
1954      """
1955      Finds a SageMaker endpoint with the specified name in the caller's AWS account, returning a
1956      NoneType if the endpoint is not found.
1957  
1958      Args:
1959          sage_client: A boto3 client for SageMaker.
1960  
1961      Returns:
1962          If the endpoint exists, a dictionary of endpoint attributes. If the endpoint does not
1963          exist, ``None``.
1964      """
1965      endpoints_page = sage_client.list_endpoints(MaxResults=100, NameContains=endpoint_name)
1966  
1967      while True:
1968          for endpoint in endpoints_page["Endpoints"]:
1969              if endpoint["EndpointName"] == endpoint_name:
1970                  return endpoint
1971  
1972          if "NextToken" in endpoints_page:
1973              endpoints_page = sage_client.list_endpoints(
1974                  MaxResults=100, NextToken=endpoints_page["NextToken"], NameContains=endpoint_name
1975              )
1976          else:
1977              return None
1978  
1979  
1980  def _find_transform_job(job_name, sage_client):
1981      """
1982      Finds a SageMaker batch transform job with the specified name in the caller's AWS account,
1983      returning a NoneType if the transform job is not found.
1984  
1985      Args:
1986          sage_client: A boto3 client for SageMaker.
1987  
1988      Returns:
1989          If the transform job exists, a dictionary of transform job attributes. If the
1990          transform job does not exist, ``None``.
1991      """
1992      transform_jobs_page = sage_client.list_transform_jobs(MaxResults=100, NameContains=job_name)
1993  
1994      while True:
1995          for transform_job in transform_jobs_page["TransformJobSummaries"]:
1996              if transform_job["TransformJobName"] == job_name:
1997                  return transform_job
1998  
1999          if "NextToken" in transform_jobs_page:
2000              transform_jobs_page = sage_client.list_transform_jobs(
2001                  MaxResults=100,
2002                  NextToken=transform_jobs_page["NextToken"],
2003                  NameContains=job_name,
2004              )
2005          else:
2006              return None
2007  
2008  
2009  def _does_model_exist(model_name, sage_client):
2010      """
2011      Determines whether a SageMaker model exists with the specified name in the caller's AWS account,
2012      returning True if the model exists, returning False if the model does not exist.
2013  
2014      Args:
2015          sage_client: A boto3 client for SageMaker.
2016  
2017      Returns:
2018          If the model exists, ``True``. If the model does not
2019          exist, ``False``.
2020      """
2021      try:
2022          response = sage_client.describe_model(ModelName=model_name)
2023      except sage_client.exceptions.ClientError as error:
2024          if "Could not find model" in error.response["Error"]["Message"]:
2025              return False
2026      else:
2027          return bool(response)
2028  
2029  
2030  class SageMakerDeploymentClient(BaseDeploymentClient):
2031      """
2032      Initialize a deployment client for SageMaker. The default region and assumed role ARN will
2033      be set according to the value of the `target_uri`.
2034  
2035      This class is meant to supersede the other ``mlflow.sagemaker`` real-time serving API's.
2036      It is also designed to be used through the :py:mod:`mlflow.deployments` module.
2037      This means that you can deploy to SageMaker using the
2038      `mlflow deployments CLI <https://www.mlflow.org/docs/latest/cli.html#mlflow-deployments>`_ and
2039      get a client through the :py:mod:`mlflow.deployments.get_deploy_client` function.
2040  
2041      Args:
2042          target_uri: A URI that follows one of the following formats:
2043  
2044              - ``sagemaker``: This will set the default region to `us-west-2` and
2045                the default assumed role ARN to `None`.
2046              - ``sagemaker:/region_name``: This will set the default region to
2047                `region_name` and the default assumed role ARN to `None`.
2048              - ``sagemaker:/region_name/assumed_role_arn``: This will set the default
2049                region to `region_name` and the default assumed role ARN to
2050                `assumed_role_arn`.
2051  
2052              When an `assumed_role_arn` is provided without a `region_name`,
2053              an MlflowException will be raised.
2054      """
2055  
2056      def __init__(self, target_uri):
2057          super().__init__(target_uri=target_uri)
2058  
2059          # Default region_name and assumed_role_arn when
2060          # the target_uri is `sagemaker` or `sagemaker:/`
2061          self.region_name = DEFAULT_REGION_NAME
2062          self.assumed_role_arn = None
2063          self._get_values_from_target_uri()
2064  
2065      def _get_values_from_target_uri(self):
2066          parsed = urllib.parse.urlparse(self.target_uri)
2067          values_str = parsed.path.strip("/")
2068  
2069          if not parsed.scheme or not values_str:
2070              return
2071  
2072          separator_index = values_str.find("/")
2073          if separator_index == -1:
2074              # values_str would look like us-east-1
2075              self.region_name = values_str
2076          else:
2077              # values_str could look like us-east-1/arn:aws:1234:role/assumed_role
2078              self.region_name = values_str[:separator_index]
2079              self.assumed_role_arn = values_str[separator_index + 1 :]
2080  
2081              # if values_str contains multiple interior slashes such as
2082              # us-east-1/////arn:aws:1234:role/assumed_role, remove
2083              # the extra slashes that come before "arn"
2084              self.assumed_role_arn = self.assumed_role_arn.strip("/")
2085  
2086          if self.region_name.startswith("arn"):
2087              raise MlflowException(
2088                  message=(
2089                      "It looks like the target_uri contains an IAM role ARN without a region name.\n"
2090                      "A region name must be provided when the target_uri contains a role ARN.\n"
2091                      "In this case, the target_uri must follow the format: "
2092                      "sagemaker:/region_name/assumed_role_arn.\n"
2093                      f"The provided target_uri is: {self.target_uri}\n"
2094                  ),
2095                  error_code=INVALID_PARAMETER_VALUE,
2096              )
2097  
2098      def _default_deployment_config(self, create_mode=True):
2099          config = {
2100              "assume_role_arn": self.assumed_role_arn,
2101              "execution_role_arn": None,
2102              "bucket": None,
2103              "image_url": None,
2104              "region_name": self.region_name,
2105              "archive": False,
2106              "instance_type": DEFAULT_SAGEMAKER_INSTANCE_TYPE,
2107              "instance_count": DEFAULT_SAGEMAKER_INSTANCE_COUNT,
2108              "vpc_config": None,
2109              "data_capture_config": None,
2110              "synchronous": True,
2111              "timeout_seconds": 1200,
2112              "variant_name": None,
2113              "env": None,
2114              "tags": None,
2115              "async_inference_config": {},
2116              "serverless_config": {},
2117          }
2118  
2119          if create_mode:
2120              config["mode"] = DEPLOYMENT_MODE_CREATE
2121          else:
2122              config["mode"] = DEPLOYMENT_MODE_REPLACE
2123  
2124          return config
2125  
2126      def _apply_custom_config(self, config, custom_config):
2127          int_fields = {"instance_count", "timeout_seconds"}
2128          bool_fields = {"synchronous", "archive"}
2129          dict_fields = {
2130              "vpc_config",
2131              "data_capture_config",
2132              "tags",
2133              "env",
2134              "async_inference_config",
2135              "serverless_config",
2136          }
2137          for key, value in custom_config.items():
2138              if key not in config:
2139                  continue
2140  
2141              if key in int_fields and not isinstance(value, int):
2142                  value = int(value)
2143              elif key in bool_fields and not isinstance(value, bool):
2144                  value = value == "True"
2145              elif key in dict_fields and not isinstance(value, dict):
2146                  value = json.loads(value)
2147  
2148              config[key] = value
2149  
2150      def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
2151          """
2152          Deploy an MLflow model on AWS SageMaker.
2153          The currently active AWS account must have correct permissions set up.
2154  
2155          This function creates a SageMaker endpoint. For more information about the input data
2156          formats accepted by this endpoint, see the
2157          `MLflow deployment tools documentation <../../deployment/deploy-model-to-sagemaker.html>`_.
2158  
2159          Args:
2160              name: Name of the deployed application.
2161              model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker.
2162                  For example:
2163  
2164                  - ``/Users/me/path/to/local/model``
2165                  - ``relative/path/to/local/model``
2166                  - ``s3://my_bucket/path/to/model``
2167                  - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
2168                  - ``models:/<model_name>/<model_version>``
2169                  - ``models:/<model_name>/<stage>``
2170  
2171                  For more information about supported URI schemes, see
2172                  `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
2173                  artifact-locations>`_.
2174              flavor: The name of the flavor of the model to use for deployment. Must be either
2175                  ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS.
2176                  If ``None``, a flavor is automatically selected from the model's available
2177                  flavors. If the specified flavor is not present or not supported for
2178                  deployment, an exception will be thrown.
2179              config: Configuration parameters. The supported parameters are:
2180  
2181                  - ``assume_role_arn``: The name of an IAM cross-account role to be assumed
2182                    to deploy SageMaker to another AWS account. If this parameter is not
2183                    specified, the role given in the ``target_uri`` will be used. If the
2184                    role is not given in the ``target_uri``, defaults to ``us-west-2``.
2185  
2186                  - ``execution_role_arn``: The name of an IAM role granting the SageMaker
2187                    service permissions to access the specified Docker image and S3 bucket
2188                    containing MLflow model artifacts. If unspecified, the currently-assumed
2189                    role will be used. This execution role is passed to the SageMaker service
2190                    when creating a SageMaker model from the specified MLflow model. It is
2191                    passed as the ``ExecutionRoleArn`` parameter of the `SageMaker
2192                    CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/
2193                    dg/API_CreateModel.html>`_. This role is *not* assumed for any other
2194                    call. For more information about SageMaker execution roles for model
2195                    creation, see
2196                    https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html.
2197  
2198                  - ``bucket``: S3 bucket where model artifacts will be stored. Defaults to a
2199                    SageMaker-compatible bucket name.
2200  
2201                  - ``image_url``: URL of the ECR-hosted Docker image the model should be
2202                    deployed into, produced by ``mlflow sagemaker build-and-push-container``.
2203                    This parameter can also be specified by the environment variable
2204                    ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``.
2205  
2206                  - ``region_name``: Name of the AWS region to which to deploy the application.
2207                    If unspecified, use the region name given in the ``target_uri``.
2208                    If it is also not specified in the ``target_uri``,
2209                    defaults to ``us-west-2``.
2210  
2211                  - ``archive``: If ``True``, any pre-existing SageMaker application resources
2212                    that become inactive (i.e. as a result of deploying in
2213                    ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` mode) are preserved.
2214                    These resources may include unused SageMaker models and endpoint
2215                    configurations that were associated with a prior version of the
2216                    application endpoint. If ``False``, these resources are deleted.
2217                    In order to use ``archive=False``, ``create_deployment()`` must be executed
2218                    synchronously with ``synchronous=True``. Defaults to ``False``.
2219  
2220                  - ``instance_type``: The type of SageMaker ML instance on which to deploy the
2221                    model. For a list of supported instance types, see
2222                    https://aws.amazon.com/sagemaker/pricing/instance-types/.
2223                    Defaults to ``ml.m4.xlarge``.
2224  
2225                  - ``instance_count``: The number of SageMaker ML instances on which to deploy
2226                    the model. Defaults to ``1``.
2227  
2228                  - ``synchronous``: If ``True``, this function will block until the deployment
2229                    process succeeds or encounters an irrecoverable failure. If ``False``,
2230                    this function will return immediately after starting the deployment
2231                    process. It will not wait for the deployment process to complete;
2232                    in this case, the caller is responsible for monitoring the health and
2233                    status of the pending deployment via native SageMaker APIs or the AWS
2234                    console. Defaults to ``True``.
2235  
2236                  - ``timeout_seconds``: If ``synchronous`` is ``True``, the deployment process
2237                    will return after the specified number of seconds if no definitive result
2238                    (success or failure) is achieved. Once the function returns, the caller is
2239                    responsible for monitoring the health and status of the pending
2240                    deployment using native SageMaker APIs or the AWS console. If
2241                    ``synchronous`` is ``False``, this parameter is ignored.
2242                    Defaults to ``300``.
2243  
2244                  - ``vpc_config``: A dictionary specifying the VPC configuration to use when
2245                    creating the new SageMaker model associated with this application.
2246                    The acceptable values for this parameter are identical to those of the
2247                    ``VpcConfig`` parameter in the `SageMaker boto3 client's create_model
2248                    method <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html
2249                    #SageMaker.Client.create_model>`_. For more information, see
2250                    https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html.
2251                    Defaults to ``None``.
2252  
2253                  - ``data_capture_config``: A dictionary specifying the data capture
2254                    configuration to use when creating the new SageMaker model associated with
2255                    this application.
2256                    For more information, see
2257                    https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html.
2258                    Defaults to ``None``.
2259  
2260                  - ``variant_name``: A string specifying the desired name when creating a production
2261                    variant.  Defaults to ``None``.
2262  
2263                  - ``async_inference_config``: A dictionary specifying the
2264                    async_inference_configuration
2265  
2266                  - ``serverless_config``: A dictionary specifying the serverless_configuration
2267  
2268                  - ``env``: A dictionary specifying environment variables as key-value
2269                    pairs to be set for the deployed model. Defaults to ``None``.
2270  
2271                  - ``tags``: A dictionary of key-value pairs representing additional
2272                    tags to be set for the deployed model. Defaults to ``None``.
2273  
2274              endpoint: (optional) Endpoint to create the deployment under. Currently unsupported
2275  
2276          .. code-block:: python
2277              :caption: Python example
2278  
2279              from mlflow.deployments import get_deploy_client
2280  
2281              vpc_config = {
2282                  "SecurityGroupIds": [
2283                      "sg-123456abc",
2284                  ],
2285                  "Subnets": [
2286                      "subnet-123456abc",
2287                  ],
2288              }
2289              config = dict(
2290                  assume_role_arn="arn:aws:123:role/assumed_role",
2291                  execution_role_arn="arn:aws:456:role/execution_role",
2292                  bucket_name="my-s3-bucket",
2293                  image_url="1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1",
2294                  region_name="us-east-1",
2295                  archive=False,
2296                  instance_type="ml.m5.4xlarge",
2297                  instance_count=1,
2298                  synchronous=True,
2299                  timeout_seconds=300,
2300                  vpc_config=vpc_config,
2301                  variant_name="prod-variant-1",
2302                  env={"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": '"--timeout 60"'},
2303                  tags={"training_timestamp": "2022-11-01T05:12:26"},
2304              )
2305              client = get_deploy_client("sagemaker")
2306              client.create_deployment(
2307                  "my-deployment",
2308                  model_uri="/mlruns/0/abc/model",
2309                  flavor="python_function",
2310                  config=config,
2311              )
2312          .. code-block:: bash
2313              :caption:  Command-line example
2314  
2315              mlflow deployments create --target sagemaker:/us-east-1/arn:aws:123:role/assumed_role \\
2316                      --name my-deployment \\
2317                      --model-uri /mlruns/0/abc/model \\
2318                      --flavor python_function\\
2319                      -C execution_role_arn=arn:aws:456:role/execution_role \\
2320                      -C bucket_name=my-s3-bucket \\
2321                      -C image_url=1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1 \\
2322                      -C region_name=us-east-1 \\
2323                      -C archive=False \\
2324                      -C instance_type=ml.m5.4xlarge \\
2325                      -C instance_count=1 \\
2326                      -C synchronous=True \\
2327                      -C timeout_seconds=300 \\
2328                      -C variant_name=prod-variant-1 \\
2329                      -C vpc_config='{"SecurityGroupIds": ["sg-123456abc"], \\
2330                      "Subnets": ["subnet-123456abc"]}' \\
2331                      -C data_capture_config='{"EnableCapture": True, \\
2332                      'InitialSamplingPercentage': 100, 'DestinationS3Uri": 's3://my-bucket/path', \\
2333                      'CaptureOptions': [{'CaptureMode': 'Output'}]}'
2334                      -C env='{"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": "\"--timeout 60\""}' \\
2335                      -C tags='{"training_timestamp": "2022-11-01T05:12:26"}' \\
2336          """
2337          final_config = self._default_deployment_config()
2338          if config:
2339              self._apply_custom_config(final_config, config)
2340  
2341          app_name, flavor = _deploy(
2342              app_name=name,
2343              model_uri=model_uri,
2344              flavor=flavor,
2345              execution_role_arn=final_config["execution_role_arn"],
2346              assume_role_arn=final_config["assume_role_arn"],
2347              bucket=final_config["bucket"],
2348              image_url=final_config["image_url"],
2349              region_name=final_config["region_name"],
2350              mode=mlflow.sagemaker.DEPLOYMENT_MODE_CREATE,
2351              archive=final_config["archive"],
2352              instance_type=final_config["instance_type"],
2353              instance_count=final_config["instance_count"],
2354              vpc_config=final_config["vpc_config"],
2355              data_capture_config=final_config["data_capture_config"],
2356              synchronous=final_config["synchronous"],
2357              timeout_seconds=final_config["timeout_seconds"],
2358              variant_name=final_config["variant_name"],
2359              async_inference_config=final_config["async_inference_config"],
2360              serverless_config=final_config["serverless_config"],
2361              env=final_config["env"],
2362              tags=final_config["tags"],
2363          )
2364  
2365          return {"name": app_name, "flavor": flavor}
2366  
2367      def update_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
2368          """
2369          Update a deployment on AWS SageMaker. This function can replace or add a new model to
2370          an existing SageMaker endpoint. By default, this function replaces the existing model
2371          with the new one. The currently active AWS account must have correct permissions set up.
2372  
2373          Args:
2374              name: Name of the deployed application.
2375              model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker.
2376                  For example:
2377  
2378                  - ``/Users/me/path/to/local/model``
2379                  - ``relative/path/to/local/model``
2380                  - ``s3://my_bucket/path/to/model``
2381                  - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
2382                  - ``models:/<model_name>/<model_version>``
2383                  - ``models:/<model_name>/<stage>``
2384  
2385                  For more information about supported URI schemes, see
2386                  `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
2387                  artifact-locations>`_.
2388  
2389              flavor: The name of the flavor of the model to use for deployment. Must be either
2390                  ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS.
2391                  If ``None``, a flavor is automatically selected from the model's available
2392                  flavors. If the specified flavor is not present or not supported for
2393                  deployment, an exception will be thrown.
2394  
2395              config: Configuration parameters. The supported parameters are:
2396  
2397                  - ``assume_role_arn``: The name of an IAM cross-account role to be assumed
2398                    to deploy SageMaker to another AWS account. If this parameter is not
2399                    specified, the role given in the ``target_uri`` will be used. If the
2400                    role is not given in the ``target_uri``, defaults to ``us-west-2``.
2401  
2402                  - ``execution_role_arn``: The name of an IAM role granting the SageMaker
2403                    service permissions to access the specified Docker image and S3 bucket
2404                    containing MLflow model artifacts. If unspecified, the currently-assumed
2405                    role will be used. This execution role is passed to the SageMaker service
2406                    when creating a SageMaker model from the specified MLflow model. It is
2407                    passed as the ``ExecutionRoleArn`` parameter of the `SageMaker
2408                    CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/
2409                    dg/API_CreateModel.html>`_. This role is *not* assumed for any other
2410                    call. For more information about SageMaker execution roles for model
2411                    creation, see
2412                    https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html.
2413  
2414                  - ``bucket``: S3 bucket where model artifacts will be stored. Defaults to a
2415                    SageMaker-compatible bucket name.
2416  
2417                  - ``image_url``: URL of the ECR-hosted Docker image the model should be
2418                    deployed into, produced by ``mlflow sagemaker build-and-push-container``.
2419                    This parameter can also be specified by the environment variable
2420                    ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``.
2421  
2422                  - ``region_name``: Name of the AWS region to which to deploy the application.
2423                    If unspecified, use the region name given in the ``target_uri``.
2424                    If it is also not specified in the ``target_uri``,
2425                    defaults to ``us-west-2``.
2426  
2427                  - ``mode``: The mode in which to deploy the application.
2428                    Must be one of the following:
2429  
2430                    ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE``
2431                        If an application of the specified name exists, its model(s) is
2432                        replaced with the specified model. If no such application exists,
2433                        it is created with the specified name and model.
2434                        This is the default mode.
2435  
2436                    ``mlflow.sagemaker.DEPLOYMENT_MODE_ADD``
2437                        Add the specified model to a pre-existing application with the
2438                        specified name, if one exists. If the application does not exist,
2439                        a new application is created with the specified name and model.
2440                        NOTE: If the application **already exists**, the specified model is
2441                        added to the application's corresponding SageMaker endpoint with an
2442                        initial weight of zero (0). To route traffic to the model,
2443                        update the application's associated endpoint configuration using
2444                        either the AWS console or the ``UpdateEndpointWeightsAndCapacities``
2445                        function defined in https://docs.aws.amazon.com/sagemaker/latest/dg/API_UpdateEndpointWeightsAndCapacities.html.
2446  
2447                  - ``archive``: If ``True``, any pre-existing SageMaker application resources
2448                    that become inactive (i.e. as a result of deploying in
2449                    ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` mode) are preserved.
2450                    These resources may include unused SageMaker models and endpoint
2451                    configurations that were associated with a prior version of the
2452                    application endpoint. If ``False``, these resources are deleted.
2453                    In order to use ``archive=False``, ``update_deployment()`` must be executed
2454                    synchronously with ``synchronous=True``. Defaults to ``False``.
2455  
2456                  - ``instance_type``: The type of SageMaker ML instance on which to deploy the
2457                    model. For a list of supported instance types, see
2458                    https://aws.amazon.com/sagemaker/pricing/instance-types/.
2459                    Defaults to ``ml.m4.xlarge``.
2460  
2461                  - ``instance_count``: The number of SageMaker ML instances on which to deploy
2462                    the model. Defaults to ``1``.
2463  
2464                  - ``synchronous``: If ``True``, this function will block until the deployment
2465                    process succeeds or encounters an irrecoverable failure. If ``False``,
2466                    this function will return immediately after starting the deployment
2467                    process. It will not wait for the deployment process to complete;
2468                    in this case, the caller is responsible for monitoring the health and
2469                    status of the pending deployment via native SageMaker APIs or the AWS
2470                    console. Defaults to ``True``.
2471  
2472                  - ``timeout_seconds``: If ``synchronous`` is ``True``, the deployment process
2473                    will return after the specified number of seconds if no definitive result
2474                    (success or failure) is achieved. Once the function returns, the caller is
2475                    responsible for monitoring the health and status of the pending
2476                    deployment using native SageMaker APIs or the AWS console. If
2477                    ``synchronous`` is ``False``, this parameter is ignored.
2478                    Defaults to ``300``.
2479  
2480                  - ``variant_name``: A string specifying the desired name when creating a
2481                    production variant.  Defaults to ``None``.
2482  
2483                  - ``vpc_config``: A dictionary specifying the VPC configuration to use when
2484                    creating the new SageMaker model associated with this application.
2485                    The acceptable values for this parameter are identical to those of the
2486                    ``VpcConfig`` parameter in the `SageMaker boto3 client's create_model
2487                    method <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html
2488                    #SageMaker.Client.create_model>`_. For more information, see
2489                    https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html.
2490                    Defaults to ``None``.
2491  
2492                  - ``data_capture_config``: A dictionary specifying the data capture
2493                    configuration to use when creating the new SageMaker model associated with
2494                    this application.
2495                    For more information, see
2496                    https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html.
2497                    Defaults to ``None``.
2498  
2499                  - ``async_inference_config``: A dictionary specifying the async config
2500                    configuration. Defaults to ``None``.
2501  
2502                  - ``env``: A dictionary specifying environment variables as key-value pairs
2503                    to be set for the deployed model. Defaults to ``None``.
2504  
2505                  - ``tags``: A dictionary of key-value pairs representing additional tags
2506                    to be set for the deployed model. Defaults to ``None``.
2507  
2508              endpoint: (optional) Endpoint containing the deployment to update. Currently unsupported
2509  
2510          .. code-block:: python
2511              :caption: Python example
2512  
2513              from mlflow.deployments import get_deploy_client
2514  
2515              vpc_config = {
2516                  "SecurityGroupIds": [
2517                      "sg-123456abc",
2518                  ],
2519                  "Subnets": [
2520                      "subnet-123456abc",
2521                  ],
2522              }
2523              data_capture_config = {
2524                  "EnableCapture": True,
2525                  "InitialSamplingPercentage": 100,
2526                  "DestinationS3Uri": "s3://my-bucket/path",
2527                  "CaptureOptions": [{"CaptureMode": "Output"}],
2528              }
2529              config = dict(
2530                  assume_role_arn="arn:aws:123:role/assumed_role",
2531                  execution_role_arn="arn:aws:456:role/execution_role",
2532                  bucket_name="my-s3-bucket",
2533                  image_url="1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1",
2534                  region_name="us-east-1",
2535                  mode="replace",
2536                  archive=False,
2537                  instance_type="ml.m5.4xlarge",
2538                  instance_count=1,
2539                  synchronous=True,
2540                  timeout_seconds=300,
2541                  variant_name="prod-variant-1",
2542                  vpc_config=vpc_config,
2543                  data_capture_config=data_capture_config,
2544                  env={"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": '"--timeout 60"'},
2545                  tags={"training_timestamp": "2022-11-01T05:12:26"},
2546              )
2547              client = get_deploy_client("sagemaker")
2548              client.update_deployment(
2549                  "my-deployment",
2550                  model_uri="/mlruns/0/abc/model",
2551                  flavor="python_function",
2552                  config=config,
2553              )
2554          .. code-block:: bash
2555              :caption:  Command-line example
2556  
2557              mlflow deployments update --target sagemaker:/us-east-1/arn:aws:123:role/assumed_role \\
2558                      --name my-deployment \\
2559                      --model-uri /mlruns/0/abc/model \\
2560                      --flavor python_function\\
2561                      -C execution_role_arn=arn:aws:456:role/execution_role \\
2562                      -C bucket_name=my-s3-bucket \\
2563                      -C image_url=1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1 \\
2564                      -C region_name=us-east-1 \\
2565                      -C mode=replace \\
2566                      -C archive=False \\
2567                      -C instance_type=ml.m5.4xlarge \\
2568                      -C instance_count=1 \\
2569                      -C synchronous=True \\
2570                      -C timeout_seconds=300 \\
2571                      -C variant_name=prod-variant-1 \\
2572                      -C vpc_config='{"SecurityGroupIds": ["sg-123456abc"], \\
2573                      "Subnets": ["subnet-123456abc"]}' \\
2574                      -C data_capture_config='{"EnableCapture": True, \\
2575                      "InitialSamplingPercentage": 100, "DestinationS3Uri": "s3://my-bucket/path", \\
2576                      "CaptureOptions": [{"CaptureMode": "Output"}]}'
2577                      -C env='{"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": "\"--timeout 60\""}' \\
2578                      -C tags='{"training_timestamp": "2022-11-01T05:12:26"}' \\
2579          """
2580          final_config = self._default_deployment_config(create_mode=False)
2581          if config:
2582              self._apply_custom_config(final_config, config)
2583  
2584          if model_uri is None:
2585              raise MlflowException(
2586                  message="A model_uri must be provided when updating a SageMaker deployment",
2587                  error_code=INVALID_PARAMETER_VALUE,
2588              )
2589  
2590          if final_config["mode"] not in [DEPLOYMENT_MODE_ADD, DEPLOYMENT_MODE_REPLACE]:
2591              raise MlflowException(
2592                  message=(
2593                      f"Invalid mode `{final_config['mode']}` for deployment"
2594                      " to a pre-existing application"
2595                  ),
2596                  error_code=INVALID_PARAMETER_VALUE,
2597              )
2598  
2599          app_name, flavor = _deploy(
2600              app_name=name,
2601              model_uri=model_uri,
2602              flavor=flavor,
2603              execution_role_arn=final_config["execution_role_arn"],
2604              assume_role_arn=final_config["assume_role_arn"],
2605              bucket=final_config["bucket"],
2606              image_url=final_config["image_url"],
2607              region_name=final_config["region_name"],
2608              mode=final_config["mode"],
2609              archive=final_config["archive"],
2610              instance_type=final_config["instance_type"],
2611              instance_count=final_config["instance_count"],
2612              vpc_config=final_config["vpc_config"],
2613              data_capture_config=final_config["data_capture_config"],
2614              synchronous=final_config["synchronous"],
2615              timeout_seconds=final_config["timeout_seconds"],
2616              variant_name=final_config["variant_name"],
2617              async_inference_config=final_config["async_inference_config"],
2618              serverless_config=final_config["serverless_config"],
2619              env=final_config["env"],
2620              tags=final_config["tags"],
2621          )
2622  
2623          return {"name": app_name, "flavor": flavor}
2624  
2625      def delete_deployment(self, name, config=None, endpoint=None):
2626          """
2627          Delete a SageMaker application.
2628  
2629          Args:
2630              name: Name of the deployed application.
2631              config: Configuration parameters. The supported parameters are:
2632  
2633                  - ``assume_role_arn``: The name of an IAM role to be assumed to delete
2634                    the SageMaker deployment.
2635  
2636                  - ``region_name``: Name of the AWS region in which the application
2637                    is deployed. Defaults to ``us-west-2`` or the region provided in
2638                    the `target_uri`.
2639  
2640                  - ``archive``: If `True`, resources associated with the specified
2641                    application, such as its associated models and endpoint configuration,
2642                    are preserved. If `False`, these resources are deleted. In order to use
2643                    ``archive=False``, ``delete()`` must be executed synchronously with
2644                    ``synchronous=True``. Defaults to ``False``.
2645  
2646                  - ``synchronous``: If `True`, this function blocks until the deletion process
2647                    succeeds or encounters an irrecoverable failure. If `False`, this function
2648                    returns immediately after starting the deletion process. It will not wait
2649                    for the deletion process to complete; in this case, the caller is
2650                    responsible for monitoring the status of the deletion process via native
2651                    SageMaker APIs or the AWS console. Defaults to ``True``.
2652  
2653                  - ``timeout_seconds``: If `synchronous` is `True`, the deletion process
2654                    returns after the specified number of seconds if no definitive result
2655                    (success or failure) is achieved. Once the function returns, the caller
2656                    is responsible for monitoring the status of the deletion process via native
2657                    SageMaker APIs or the AWS console. If `synchronous` is False, this
2658                    parameter is ignored. Defaults to ``300``.
2659  
2660              endpoint: (optional) Endpoint containing the deployment to delete. Currently unsupported
2661  
2662          .. code-block:: python
2663              :caption: Python example
2664  
2665              from mlflow.deployments import get_deploy_client
2666  
2667              config = dict(
2668                  assume_role_arn="arn:aws:123:role/assumed_role",
2669                  region_name="us-east-1",
2670                  archive=False,
2671                  synchronous=True,
2672                  timeout_seconds=300,
2673              )
2674              client = get_deploy_client("sagemaker")
2675              client.delete_deployment("my-deployment", config=config)
2676  
2677          .. code-block:: bash
2678              :caption: Command-line example
2679  
2680              mlflow deployments delete --target sagemaker \\
2681                      --name my-deployment \\
2682                      -C assume_role_arn=arn:aws:123:role/assumed_role \\
2683                      -C region_name=us-east-1 \\
2684                      -C archive=False \\
2685                      -C synchronous=True \\
2686                      -C timeout_seconds=300
2687          """
2688          final_config = {
2689              "region_name": self.region_name,
2690              "archive": False,
2691              "synchronous": True,
2692              "timeout_seconds": 300,
2693              "assume_role_arn": self.assumed_role_arn,
2694          }
2695          if config:
2696              self._apply_custom_config(final_config, config)
2697  
2698          _delete(
2699              name,
2700              region_name=final_config["region_name"],
2701              assume_role_arn=final_config["assume_role_arn"],
2702              archive=final_config["archive"],
2703              synchronous=final_config["synchronous"],
2704              timeout_seconds=final_config["timeout_seconds"],
2705          )
2706  
2707      def list_deployments(self, endpoint=None):
2708          """
2709          List deployments. This method returns a list of dictionaries that describes each deployment.
2710  
2711          If a region name needs to be specified, the plugin must be initialized
2712          with the AWS region in the ``target_uri`` such as ``sagemaker:/us-east-1``.
2713  
2714          To assume an IAM role, the plugin must be initialized
2715          with the AWS region and the role ARN in the ``target_uri`` such as
2716          ``sagemaker:/us-east-1/arn:aws:1234:role/assumed_role``.
2717  
2718          Args:
2719              endpoint: (optional) List deployments in the specified endpoint. Currently unsupported
2720  
2721          Returns:
2722              A list of dictionaries corresponding to deployments.
2723  
2724          .. code-block:: python
2725              :caption: Python example
2726  
2727              from mlflow.deployments import get_deploy_client
2728  
2729              client = get_deploy_client("sagemaker:/us-east-1/arn:aws:123:role/assumed_role")
2730              client.list_deployments()
2731  
2732          .. code-block:: bash
2733              :caption: Command-line example
2734  
2735              mlflow deployments list --target sagemaker:/us-east-1/arn:aws:1234:role/assumed_role
2736          """
2737          import boto3
2738  
2739          assume_role_credentials = _assume_role_and_get_credentials(
2740              assume_role_arn=self.assumed_role_arn
2741          )
2742  
2743          sage_client = boto3.client(
2744              "sagemaker", region_name=self.region_name, **assume_role_credentials
2745          )
2746          return sage_client.list_endpoints()["Endpoints"]
2747  
2748      def get_deployment(self, name, endpoint=None):
2749          """
2750          Returns a dictionary describing the specified deployment.
2751  
2752          If a region name needs to be specified, the plugin must be initialized
2753          with the AWS region in the ``target_uri`` such as ``sagemaker:/us-east-1``.
2754  
2755          To assume an IAM role, the plugin must be initialized
2756          with the AWS region and the role ARN in the ``target_uri`` such as
2757          ``sagemaker:/us-east-1/arn:aws:1234:role/assumed_role``.
2758  
2759          A :py:class:`mlflow.exceptions.MlflowException` will also be thrown when an error occurs
2760          while retrieving the deployment.
2761  
2762          Args:
2763              name: Name of deployment to retrieve
2764              endpoint: (optional) Endpoint containing the deployment to get. Currently unsupported
2765  
2766          Returns:
2767              A dictionary that describes the specified deployment
2768  
2769          .. code-block:: python
2770              :caption: Python example
2771  
2772              from mlflow.deployments import get_deploy_client
2773  
2774              client = get_deploy_client("sagemaker:/us-east-1/arn:aws:123:role/assumed_role")
2775              client.get_deployment("my-deployment")
2776  
2777          .. code-block:: bash
2778              :caption: Command-line example
2779  
2780              mlflow deployments get --target sagemaker:/us-east-1/arn:aws:1234:role/assumed_role \\
2781                  --name my-deployment
2782          """
2783          import boto3
2784  
2785          assume_role_credentials = _assume_role_and_get_credentials(
2786              assume_role_arn=self.assumed_role_arn
2787          )
2788  
2789          try:
2790              sage_client = boto3.client(
2791                  "sagemaker", region_name=self.region_name, **assume_role_credentials
2792              )
2793              return sage_client.describe_endpoint(EndpointName=name)
2794          except Exception as exc:
2795              raise MlflowException(
2796                  message=f"There was an error while retrieving the deployment: {exc}\n"
2797              )
2798  
2799      def predict(
2800          self,
2801          deployment_name=None,
2802          inputs=None,
2803          endpoint=None,
2804          params: dict[str, Any] | None = None,
2805      ):
2806          """
2807          Compute predictions from the specified deployment using the provided PyFunc input.
2808  
2809          The input/output types of this method match the :ref:`MLflow PyFunc prediction
2810          interface <pyfunc-inference-api>`.
2811  
2812          If a region name needs to be specified, the plugin must be initialized
2813          with the AWS region in the ``target_uri`` such as ``sagemaker:/us-east-1``.
2814  
2815          To assume an IAM role, the plugin must be initialized
2816          with the AWS region and the role ARN in the ``target_uri`` such as
2817          ``sagemaker:/us-east-1/arn:aws:1234:role/assumed_role``.
2818  
2819          Args:
2820              deployment_name: Name of the deployment to predict against.
2821              inputs: Input data (or arguments) to pass to the deployment or model endpoint for
2822                  inference. For a complete list of supported input types, see
2823                  :ref:`pyfunc-inference-api`.
2824              endpoint: Endpoint to predict against. Currently unsupported
2825              params: Optional parameters to invoke the endpoint with.
2826  
2827          Returns:
2828              A PyFunc output, such as a Pandas DataFrame, Pandas Series, or NumPy array.
2829              For a complete list of supported output types, see :ref:`pyfunc-inference-api`.
2830  
2831          .. code-block:: python
2832              :caption: Python example
2833  
2834              import pandas as pd
2835              from mlflow.deployments import get_deploy_client
2836  
2837              df = pd.DataFrame(data=[[1, 2, 3]], columns=["feat1", "feat2", "feat3"])
2838              client = get_deploy_client("sagemaker:/us-east-1/arn:aws:123:role/assumed_role")
2839              client.predict("my-deployment", df)
2840  
2841          .. code-block:: bash
2842              :caption: Command-line example
2843  
2844              cat > ./input.json <<- input
2845              {"feat1": {"0": 1}, "feat2": {"0": 2}, "feat3": {"0": 3}}
2846              input
2847  
2848              mlflow deployments predict \\
2849                  --target sagemaker:/us-east-1/arn:aws:1234:role/assumed_role \\
2850                  --name my-deployment \\
2851                  --input-path ./input.json
2852          """
2853          import boto3
2854  
2855          assume_role_credentials = _assume_role_and_get_credentials(
2856              assume_role_arn=self.assumed_role_arn
2857          )
2858  
2859          try:
2860              sage_client = boto3.client(
2861                  "sagemaker-runtime", region_name=self.region_name, **assume_role_credentials
2862              )
2863              response = sage_client.invoke_endpoint(
2864                  EndpointName=deployment_name,
2865                  Body=dump_input_data(inputs, inputs_key="instances", params=params),
2866                  ContentType="application/json",
2867              )
2868              response_body = response["Body"].read().decode("utf-8")
2869              return PredictionsResponse.from_json(response_body)
2870          except Exception as exc:
2871              raise MlflowException(
2872                  message=f"There was an error while getting model prediction: {exc}\n"
2873              )
2874  
2875      def explain(self, deployment_name=None, df=None, endpoint=None):
2876          """
2877          *This function has not been implemented and will be coming in the future.*
2878          """
2879          raise NotImplementedError("This function is not implemented yet.")
2880  
2881      def create_endpoint(self, name, config=None):
2882          """
2883          Create an endpoint with the specified target. By default, this method should block until
2884          creation completes (i.e. until it's possible to create a deployment within the endpoint).
2885          In the case of conflicts (e.g. if it's not possible to create the specified endpoint
2886          due to conflict with an existing endpoint), raises a
2887          :py:class:`mlflow.exceptions.MlflowException`. See target-specific plugin documentation
2888          for additional detail on support for asynchronous creation and other configuration.
2889  
2890          Args:
2891              name: Unique name to use for endpoint. If another endpoint exists with the same
2892                          name, raises a :py:class:`mlflow.exceptions.MlflowException`.
2893              config: (optional) Dict containing target-specific configuration for the endpoint.
2894  
2895          Returns:
2896              Dict corresponding to created endpoint, which must contain the 'name' key.
2897          """
2898          raise NotImplementedError("This function is not implemented yet.")
2899  
2900      def update_endpoint(self, endpoint, config=None):
2901          """
2902          Update the endpoint with the specified name. You can update any target-specific attributes
2903          of the endpoint (via `config`). By default, this method should block until the update
2904          completes (i.e. until it's possible to create a deployment within the endpoint). See
2905          target-specific plugin documentation for additional detail on support for asynchronous
2906          update and other configuration.
2907  
2908          Args:
2909              endpoint: Unique name of endpoint to update
2910              config: (optional) dict containing target-specific configuration for the endpoint
2911          """
2912          raise NotImplementedError("This function is not implemented yet.")
2913  
2914      def delete_endpoint(self, endpoint):
2915          """
2916          Delete the endpoint from the specified target. Deletion should be idempotent (i.e. deletion
2917          should not fail if retried on a non-existent deployment).
2918  
2919          Args:
2920              endpoint: Name of endpoint to delete
2921          """
2922          raise NotImplementedError("This function is not implemented yet.")
2923  
2924      def list_endpoints(self):
2925          """
2926          List endpoints in the specified target. This method is expected to return an
2927          unpaginated list of all endpoints (an alternative would be to return a dict with
2928          an 'endpoints' field containing the actual endpoints, with plugins able to specify
2929          other fields, e.g. a next_page_token field, in the returned dictionary for pagination,
2930          and to accept a `pagination_args` argument to this method for passing
2931          pagination-related args).
2932  
2933          Returns:
2934              A list of dicts corresponding to endpoints. Each dict is guaranteed to
2935              contain a 'name' key containing the endpoint name. The other fields of
2936              the returned dictionary and their types may vary across targets.
2937          """
2938          raise NotImplementedError("This function is not implemented yet.")
2939  
2940      def get_endpoint(self, endpoint):
2941          """
2942          Returns a dictionary describing the specified endpoint, throwing a
2943          py:class:`mlflow.exception.MlflowException` if no endpoint exists with the provided
2944          name.
2945          The dict is guaranteed to contain an 'name' key containing the endpoint name.
2946          The other fields of the returned dictionary and their types may vary across targets.
2947  
2948          Args:
2949              endpoint: Name of endpoint to fetch
2950          """
2951          raise NotImplementedError("This function is not implemented yet.")
2952  
2953  
class _SageMakerOperation:
    """
    Tracks a long-running SageMaker operation: polls its status, records the outcome,
    and runs a one-shot cleanup callback after success.
    """

    def __init__(self, status_check_fn, cleanup_fn):
        # status_check_fn: zero-arg callable returning a _SageMakerOperationStatus.
        # cleanup_fn: zero-arg callable releasing resources after a successful run.
        self.status_check_fn = status_check_fn
        self.cleanup_fn = cleanup_fn
        self.start_time = time.time()
        self.status = _SageMakerOperationStatus(_SageMakerOperationStatus.STATE_IN_PROGRESS, None)
        self.cleaned_up = False

    def await_completion(self, timeout_seconds):
        """Poll until the operation leaves the in-progress state or the timeout elapses."""
        poll_count = 0
        started_at = time.time()
        while (time.time() - started_at) < timeout_seconds:
            current_status = self.status_check_fn()
            if current_status.state != _SageMakerOperationStatus.STATE_IN_PROGRESS:
                self.status = current_status
                return current_status

            # With a 5-second poll interval, logging every fourth poll emits a
            # progress line roughly every 20 seconds.
            if poll_count % 4 == 0:
                _logger.info(current_status.message)
            poll_count += 1
            time.sleep(5)

        # NOTE: a timeout is returned to the caller but does not overwrite
        # ``self.status``, so ``clean_up()`` still refuses to run.
        return _SageMakerOperationStatus.timed_out(time.time() - started_at)

    def clean_up(self):
        """Run the cleanup callback exactly once, and only after a successful operation."""
        if self.status.state != _SageMakerOperationStatus.STATE_SUCCEEDED:
            raise ValueError(
                "Cannot clean up an operation that has not succeeded! Current operation state:"
                f" {self.status.state}"
            )

        if self.cleaned_up:
            raise ValueError("`clean_up()` has already been executed for this operation!")
        # The flag is set before invoking the callback, mirroring the original
        # ordering: a failing callback still marks the operation as cleaned up.
        self.cleaned_up = True

        self.cleanup_fn()
2995  
2996  
2997  class _SageMakerOperationStatus:
2998      STATE_SUCCEEDED = "succeeded"
2999      STATE_FAILED = "failed"
3000      STATE_IN_PROGRESS = "in progress"
3001      STATE_TIMED_OUT = "timed_out"
3002  
3003      def __init__(self, state, message):
3004          self.state = state
3005          self.message = message
3006  
3007      @classmethod
3008      def in_progress(cls, message=None):
3009          if message is None:
3010              message = "The operation is still in progress."
3011          return cls(_SageMakerOperationStatus.STATE_IN_PROGRESS, message)
3012  
3013      @classmethod
3014      def timed_out(cls, duration_seconds):
3015          return cls(
3016              _SageMakerOperationStatus.STATE_TIMED_OUT,
3017              f"Timed out after waiting {duration_seconds} seconds for the operation to"
3018              " complete. This operation may still be in progress. Please check the AWS"
3019              " console for more information.",
3020          )
3021  
3022      @classmethod
3023      def failed(cls, message):
3024          return cls(_SageMakerOperationStatus.STATE_FAILED, message)
3025  
3026      @classmethod
3027      def succeeded(cls, message):
3028          return cls(_SageMakerOperationStatus.STATE_SUCCEEDED, message)