__init__.py
1 """ 2 The ``mlflow.sagemaker`` module provides an API for deploying MLflow models to Amazon SageMaker. 3 """ 4 5 import json 6 import logging 7 import os 8 import signal 9 import subprocess 10 import sys 11 import tarfile 12 import time 13 import urllib.parse 14 import uuid 15 from typing import Any 16 17 import mlflow 18 import mlflow.version 19 from mlflow import pyfunc 20 from mlflow.deployments import BaseDeploymentClient, PredictionsResponse 21 from mlflow.environment_variables import ( 22 MLFLOW_DEPLOYMENT_FLAVOR_NAME, 23 MLFLOW_SAGEMAKER_DEPLOY_IMG_URL, 24 ) 25 from mlflow.exceptions import MlflowException 26 from mlflow.models import Model 27 from mlflow.models.container import ( 28 SERVING_ENVIRONMENT, 29 ) 30 from mlflow.models.container import ( 31 SUPPORTED_FLAVORS as SUPPORTED_DEPLOYMENT_FLAVORS, 32 ) 33 from mlflow.models.model import MLMODEL_FILE_NAME 34 from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST 35 from mlflow.tracking.artifact_utils import _download_artifact_from_uri 36 from mlflow.utils.file_utils import TempDir 37 from mlflow.utils.proto_json_utils import dump_input_data 38 39 DEFAULT_IMAGE_NAME = "mlflow-pyfunc" 40 DEPLOYMENT_MODE_ADD = "add" 41 DEPLOYMENT_MODE_REPLACE = "replace" 42 DEPLOYMENT_MODE_CREATE = "create" 43 44 DEPLOYMENT_MODES = [DEPLOYMENT_MODE_CREATE, DEPLOYMENT_MODE_ADD, DEPLOYMENT_MODE_REPLACE] 45 46 DEFAULT_BUCKET_NAME_PREFIX = "mlflow-sagemaker" 47 48 DEFAULT_SAGEMAKER_INSTANCE_TYPE = "ml.m4.xlarge" 49 DEFAULT_SAGEMAKER_INSTANCE_COUNT = 1 50 51 DEFAULT_REGION_NAME = "us-west-2" 52 SAGEMAKER_SERVING_ENVIRONMENT = "SageMaker" 53 SAGEMAKER_APP_NAME_TAG_KEY = "app_name" 54 55 _logger = logging.getLogger(__name__) 56 57 _full_template = "{account}.dkr.ecr.{region}.amazonaws.com/{image}:{version}" 58 59 60 def _get_preferred_deployment_flavor(model_config): 61 """ 62 Obtains the flavor that MLflow would prefer to use when deploying the model. 63 If the model does not contain any supported flavors for deployment, an exception 64 will be thrown. 65 66 Args: 67 model_config: An MLflow model object 68 69 Returns: 70 The name of the preferred deployment flavor for the specified model 71 """ 72 if pyfunc.FLAVOR_NAME in model_config.flavors: 73 return pyfunc.FLAVOR_NAME 74 else: 75 raise MlflowException( 76 message=( 77 "The specified model does not contain any of the supported flavors for" 78 " deployment. The model contains the following flavors: {model_flavors}." 79 " Supported flavors: {supported_flavors}".format( 80 model_flavors=model_config.flavors.keys(), 81 supported_flavors=SUPPORTED_DEPLOYMENT_FLAVORS, 82 ) 83 ), 84 error_code=RESOURCE_DOES_NOT_EXIST, 85 ) 86 87 88 def _validate_deployment_flavor(model_config, flavor): 89 """ 90 Checks that the specified flavor is a supported deployment flavor 91 and is contained in the specified model. If one of these conditions 92 is not met, an exception is thrown. 93 94 Args: 95 model_config: An MLflow Model object 96 flavor: The deployment flavor to validate 97 """ 98 if flavor not in SUPPORTED_DEPLOYMENT_FLAVORS: 99 raise MlflowException( 100 message=( 101 f"The specified flavor: `{flavor}` is not supported for deployment." 102 f" Please use one of the supported flavors: {SUPPORTED_DEPLOYMENT_FLAVORS}" 103 ), 104 error_code=INVALID_PARAMETER_VALUE, 105 ) 106 elif flavor not in model_config.flavors: 107 raise MlflowException( 108 message=( 109 "The specified model does not contain the specified deployment flavor:" 110 f" `{flavor}`. Please use one of the following deployment flavors" 111 f" that the model contains: {model_config.flavors.keys()}" 112 ), 113 error_code=RESOURCE_DOES_NOT_EXIST, 114 ) 115 116 117 def push_image_to_ecr(image=DEFAULT_IMAGE_NAME): 118 """ 119 Push local Docker image to AWS ECR. 120 121 The image is pushed under currently active AWS account and to the currently active AWS region. 122 123 Args: 124 image: Docker image name. 125 """ 126 import boto3 127 128 _logger.info("Pushing image to ECR") 129 client = boto3.client("sts") 130 caller_id = client.get_caller_identity() 131 account = caller_id["Account"] 132 my_session = boto3.session.Session() 133 region = my_session.region_name or "us-west-2" 134 fullname = _full_template.format( 135 account=account, region=region, image=image, version=mlflow.version.VERSION 136 ) 137 _logger.info("Pushing docker image %s to %s", image, fullname) 138 ecr_client = boto3.client("ecr") 139 try: 140 ecr_client.describe_repositories(repositoryNames=[image])["repositories"] 141 except ecr_client.exceptions.RepositoryNotFoundException: 142 ecr_client.create_repository(repositoryName=image) 143 _logger.info("Created new ECR repository: %s", image) 144 registry = f"{account}.dkr.ecr.{region}.amazonaws.com" 145 146 try: 147 # Docker login: get password from AWS CLI and pipe to docker login 148 _logger.info("Logging in to ECR registry: %s", registry) 149 aws_result = subprocess.run( 150 ["aws", "ecr", "get-login-password"], 151 capture_output=True, 152 check=True, 153 ) 154 subprocess.run( 155 ["docker", "login", "--username", "AWS", "--password-stdin", registry], 156 input=aws_result.stdout, 157 check=True, 158 ) 159 160 # Docker tag 161 _logger.info("Tagging image %s as %s", image, fullname) 162 subprocess.check_call(["docker", "tag", image, fullname]) 163 164 # Docker push 165 _logger.info("Pushing image %s", fullname) 166 subprocess.check_call(["docker", "push", fullname]) 167 except subprocess.CalledProcessError as e: 168 cmd = " ".join(e.cmd) 169 raise MlflowException( 170 f"Failed to push image to ECR. Command '{cmd}' failed with exit code {e.returncode}" 171 ) from e 172 173 174 def _deploy( 175 app_name, 176 model_uri, 177 execution_role_arn=None, 178 assume_role_arn=None, 179 bucket=None, 180 image_url=None, 181 region_name="us-west-2", 182 mode=DEPLOYMENT_MODE_CREATE, 183 archive=False, 184 instance_type=DEFAULT_SAGEMAKER_INSTANCE_TYPE, 185 instance_count=DEFAULT_SAGEMAKER_INSTANCE_COUNT, 186 vpc_config=None, 187 flavor=None, 188 synchronous=True, 189 timeout_seconds=1200, 190 data_capture_config=None, 191 variant_name=None, 192 async_inference_config=None, 193 serverless_config=None, 194 env=None, 195 tags=None, 196 ): 197 """ 198 Deploy an MLflow model on AWS SageMaker. 199 The currently active AWS account must have correct permissions set up. 200 201 This function creates a SageMaker endpoint. For more information about the input data 202 formats accepted by this endpoint, see the 203 `MLflow deployment tools documentation <../../deployment/deploy-model-to-sagemaker.html>`_. 204 205 Args: 206 app_name: Name of the deployed application. 207 model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker. 208 For example: 209 210 - ``/Users/me/path/to/local/model`` 211 - ``relative/path/to/local/model`` 212 - ``s3://my_bucket/path/to/model`` 213 - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` 214 - ``models:/<model_name>/<model_version>`` 215 - ``models:/<model_name>/<stage>`` 216 217 For more information about supported URI schemes, see 218 `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html# 219 artifact-locations>`_. 220 221 execution_role_arn: The name of an IAM role granting the SageMaker service permissions to 222 access the specified Docker image and S3 bucket containing MLflow 223 model artifacts. If unspecified, the currently-assumed role will be 224 used. This execution role is passed to the SageMaker service when 225 creating a SageMaker model from the specified MLflow model. It is 226 passed as the ``ExecutionRoleArn`` parameter of the `SageMaker 227 CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/ 228 dg/API_CreateModel.html>`_. This role is *not* assumed for any other 229 call. For more information about SageMaker execution roles for model 230 creation, see 231 https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html. 232 assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker 233 to another AWS account. If unspecified, SageMaker will be deployed to 234 the the currently active AWS account. 235 bucket: S3 bucket where model artifacts will be stored. Defaults to a 236 SageMaker-compatible bucket name. 237 image_url: URL of the ECR-hosted Docker image the model should be deployed into, produced 238 by ``mlflow sagemaker build-and-push-container``. This parameter can also 239 be specified by the environment variable ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``. 240 region_name: Name of the AWS region to which to deploy the application. 241 mode: The mode in which to deploy the application. Must be one of the following: 242 243 ``mlflow.sagemaker.DEPLOYMENT_MODE_CREATE`` 244 Create an application with the specified name and model. This fails if an 245 application of the same name already exists. 246 247 ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` 248 If an application of the specified name exists, its model(s) is replaced with 249 the specified model. If no such application exists, it is created with the 250 specified name and model. 251 252 ``mlflow.sagemaker.DEPLOYMENT_MODE_ADD`` 253 Add the specified model to a pre-existing application with the specified name, 254 if one exists. If the application does not exist, a new application is created 255 with the specified name and model. NOTE: If the application **already exists**, 256 the specified model is added to the application's corresponding SageMaker 257 endpoint with an initial weight of zero (0). To route traffic to the model, 258 update the application's associated endpoint configuration using either the 259 AWS console or the ``UpdateEndpointWeightsAndCapacities`` function defined in 260 https://docs.aws.amazon.com/sagemaker/latest/dg/API_UpdateEndpointWeightsAndCapacities.html. 261 262 archive: If ``True``, any pre-existing SageMaker application resources that become 263 inactive (i.e. as a result of deploying in 264 ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` mode) are preserved. 265 These resources may include unused SageMaker models and endpoint configurations 266 that were associated with a prior version of the application endpoint. If 267 ``False``, these resources are deleted. In order to use ``archive=False``, 268 ``deploy()`` must be executed synchronously with ``synchronous=True``. 269 instance_type: The type of SageMaker ML instance on which to deploy the model. For a list 270 of supported instance types, see 271 https://aws.amazon.com/sagemaker/pricing/instance-types/. 272 instance_count: The number of SageMaker ML instances on which to deploy the model. 273 vpc_config: A dictionary specifying the VPC configuration to use when creating the 274 new SageMaker model associated with this application. The acceptable values 275 for this parameter are identical to those of the ``VpcConfig`` parameter in 276 the `SageMaker boto3 client's create_model method 277 <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html 278 #SageMaker.Client.create_model>`_. For more information, see 279 https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html. 280 281 .. code-block:: python 282 :caption: Example 283 284 import mlflow.sagemaker as mfs 285 286 vpc_config = { 287 "SecurityGroupIds": [ 288 "sg-123456abc", 289 ], 290 "Subnets": [ 291 "subnet-123456abc", 292 ], 293 } 294 mfs._deploy(..., vpc_config=vpc_config) 295 296 flavor: The name of the flavor of the model to use for deployment. Must be either 297 ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. If ``None``, 298 a flavor is automatically selected from the model's available flavors. If the 299 specified flavor is not present or not supported for deployment, an exception 300 will be thrown. 301 synchronous: If ``True``, this function will block until the deployment process succeeds 302 or encounters an irrecoverable failure. If ``False``, this function will 303 return immediately after starting the deployment process. It will not wait 304 for the deployment process to complete; in this case, the caller is 305 responsible for monitoring the health and status of the pending deployment 306 via native SageMaker APIs or the AWS console. 307 timeout_seconds: If ``synchronous`` is ``True``, the deployment process will return after 308 the specified number of seconds if no definitive result (success or 309 failure) is achieved. Once the function returns, the caller is 310 responsible for monitoring the health and status of the pending 311 deployment using native SageMaker APIs or the AWS console. If 312 ``synchronous`` is ``False``, this parameter is ignored. 313 data_capture_config: A dictionary specifying the data capture configuration to use when 314 creating the new SageMaker model associated with this application. 315 For more information, see 316 https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html. 317 318 .. code-block:: python 319 :caption: Example 320 321 import mlflow.sagemaker as mfs 322 323 data_capture_config = { 324 "EnableCapture": True, 325 "InitialSamplingPercentage": 100, 326 "DestinationS3Uri": "s3://my-bucket/path", 327 "CaptureOptions": [{"CaptureMode": "Output"}], 328 } 329 mfs._deploy(..., data_capture_config=data_capture_config) 330 331 variant_name: The name to assign to the new production variant. 332 async_inference_config: The name to assign to the endpoint_config 333 on the sagemaker endpoint. 334 .. code-block:: python 335 :caption: Example 336 337 { 338 "AsyncInferenceConfig": { 339 "ClientConfig": {"MaxConcurrentInvocationsPerInstance": 4}, 340 "OutputConfig": { 341 "S3OutputPath": "s3://<path-to-output-bucket>", 342 "NotificationConfig": {}, 343 }, 344 } 345 } 346 347 serverless_config: An optional dictionary specifying the serverless configuration 348 .. code-block:: python 349 :caption: Example 350 351 { 352 "ServerlessConfig": { 353 "MemorySizeInMB": 2048, 354 "MaxConcurrency": 20, 355 } 356 } 357 358 env: An optional dictionary of environment variables to set for the model. 359 tags: An optional dictionary of tags to apply to the endpoint. 360 """ 361 import boto3 362 363 if (not archive) and (not synchronous): 364 raise MlflowException( 365 message=( 366 "Resources must be archived when `deploy()` is executed in non-synchronous mode." 367 " Either set `synchronous=True` or `archive=True`." 368 ), 369 error_code=INVALID_PARAMETER_VALUE, 370 ) 371 372 if mode not in DEPLOYMENT_MODES: 373 raise MlflowException( 374 message="`mode` must be one of: {deployment_modes}".format( 375 deployment_modes=",".join(DEPLOYMENT_MODES) 376 ), 377 error_code=INVALID_PARAMETER_VALUE, 378 ) 379 380 model_path = _download_artifact_from_uri(model_uri) 381 model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME) 382 if not os.path.exists(model_config_path): 383 raise MlflowException( 384 message=( 385 f"Failed to find {MLMODEL_FILE_NAME} configuration within the specified model's " 386 "root directory." 387 ), 388 error_code=INVALID_PARAMETER_VALUE, 389 ) 390 model_config = Model.load(model_config_path) 391 392 if flavor is None: 393 flavor = _get_preferred_deployment_flavor(model_config) 394 else: 395 _validate_deployment_flavor(model_config, flavor) 396 _logger.info("Using the %s flavor for deployment!", flavor) 397 398 assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn) 399 400 s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials) 401 sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials) 402 403 endpoint_exists = _find_endpoint(endpoint_name=app_name, sage_client=sage_client) is not None 404 if endpoint_exists and mode == DEPLOYMENT_MODE_CREATE: 405 raise MlflowException( 406 message=( 407 f"You are attempting to deploy an application with name: {app_name} in" 408 f" '{DEPLOYMENT_MODE_CREATE}' mode. However, an application with the same name" 409 " already exists. If you want to update this application, deploy in" 410 f" '{DEPLOYMENT_MODE_ADD}' or '{DEPLOYMENT_MODE_REPLACE}' mode." 411 ), 412 error_code=INVALID_PARAMETER_VALUE, 413 ) 414 415 model_name = _get_sagemaker_model_name(endpoint_name=app_name) 416 417 if not image_url: 418 image_url = _get_default_image_url(region_name=region_name) 419 if not execution_role_arn: 420 execution_role_arn = _get_assumed_role_arn(**assume_role_credentials) 421 if not bucket: 422 _logger.info("No model data bucket specified, using the default bucket") 423 bucket = _get_default_s3_bucket(region_name, **assume_role_credentials) 424 425 model_s3_path = _upload_s3( 426 local_model_path=model_path, 427 bucket=bucket, 428 prefix=model_name, 429 region_name=region_name, 430 s3_client=s3_client, 431 **assume_role_credentials, 432 ) 433 434 if endpoint_exists: 435 deployment_operation = _update_sagemaker_endpoint( 436 endpoint_name=app_name, 437 model_name=model_name, 438 model_s3_path=model_s3_path, 439 model_uri=model_uri, 440 image_url=image_url, 441 flavor=flavor, 442 instance_type=instance_type, 443 instance_count=instance_count, 444 vpc_config=vpc_config, 445 mode=mode, 446 role=execution_role_arn, 447 sage_client=sage_client, 448 s3_client=s3_client, 449 variant_name=variant_name, 450 async_inference_config=async_inference_config, 451 serverless_config=serverless_config, 452 data_capture_config=data_capture_config, 453 env=env, 454 tags=tags, 455 ) 456 else: 457 deployment_operation = _create_sagemaker_endpoint( 458 endpoint_name=app_name, 459 model_name=model_name, 460 model_s3_path=model_s3_path, 461 model_uri=model_uri, 462 image_url=image_url, 463 flavor=flavor, 464 instance_type=instance_type, 465 instance_count=instance_count, 466 vpc_config=vpc_config, 467 data_capture_config=data_capture_config, 468 role=execution_role_arn, 469 sage_client=sage_client, 470 variant_name=variant_name, 471 async_inference_config=async_inference_config, 472 serverless_config=serverless_config, 473 env=env, 474 tags=tags, 475 ) 476 477 if synchronous: 478 _logger.info("Waiting for the deployment operation to complete...") 479 operation_status = deployment_operation.await_completion(timeout_seconds=timeout_seconds) 480 if operation_status.state == _SageMakerOperationStatus.STATE_SUCCEEDED: 481 _logger.info( 482 'The deployment operation completed successfully with message: "%s"', 483 operation_status.message, 484 ) 485 else: 486 raise MlflowException( 487 "The deployment operation failed with the following error message:" 488 f' "{operation_status.message}"' 489 ) 490 if not archive: 491 deployment_operation.clean_up() 492 493 return app_name, flavor 494 495 496 def _delete( 497 app_name, 498 region_name="us-west-2", 499 assume_role_arn=None, 500 archive=False, 501 synchronous=True, 502 timeout_seconds=300, 503 ): 504 """ 505 Delete a SageMaker application. 506 507 Args: 508 app_name: Name of the deployed application. 509 region_name: Name of the AWS region in which the application is deployed. 510 assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker 511 to another AWS account. If unspecified, SageMaker will be deployed to 512 the the currently active AWS account. 513 archive: If ``True``, resources associated with the specified application, such 514 as its associated models and endpoint configuration, are preserved. 515 If ``False``, these resources are deleted. In order to use 516 ``archive=False``, ``delete()`` must be executed synchronously with 517 ``synchronous=True``. 518 synchronous: If `True`, this function blocks until the deletion process succeeds 519 or encounters an irrecoverable failure. If `False`, this function 520 returns immediately after starting the deletion process. It will not wait 521 for the deletion process to complete; in this case, the caller is 522 responsible for monitoring the status of the deletion process via native 523 SageMaker APIs or the AWS console. 524 timeout_seconds: If `synchronous` is `True`, the deletion process returns after the 525 specified number of seconds if no definitive result (success or failure) 526 is achieved. Once the function returns, the caller is responsible 527 for monitoring the status of the deletion process via native SageMaker 528 APIs or the AWS console. If `synchronous` is False, this parameter 529 is ignored. 530 """ 531 import boto3 532 533 if (not archive) and (not synchronous): 534 raise MlflowException( 535 message=( 536 "Resources must be archived when `delete()` is executed in non-synchronous mode." 537 " Either set `synchronous=True` or `archive=True`." 538 ), 539 error_code=INVALID_PARAMETER_VALUE, 540 ) 541 542 assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn) 543 544 s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials) 545 sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials) 546 547 endpoint_info = sage_client.describe_endpoint(EndpointName=app_name) 548 endpoint_arn = endpoint_info["EndpointArn"] 549 550 sage_client.delete_endpoint(EndpointName=app_name) 551 _logger.info("Deleted endpoint with arn: %s", endpoint_arn) 552 553 def status_check_fn(): 554 endpoint_info = _find_endpoint(endpoint_name=app_name, sage_client=sage_client) 555 if endpoint_info is not None: 556 return _SageMakerOperationStatus.in_progress( 557 "Deletion is still in progress. Current endpoint status: {endpoint_status}".format( 558 endpoint_status=endpoint_info["EndpointStatus"] 559 ) 560 ) 561 else: 562 return _SageMakerOperationStatus.succeeded( 563 "The SageMaker endpoint was deleted successfully." 564 ) 565 566 def cleanup_fn(): 567 _logger.info("Cleaning up unused resources...") 568 config_name = endpoint_info["EndpointConfigName"] 569 config_info = sage_client.describe_endpoint_config(EndpointConfigName=config_name) 570 config_arn = config_info["EndpointConfigArn"] 571 sage_client.delete_endpoint_config(EndpointConfigName=config_name) 572 _logger.info("Deleted associated endpoint configuration with arn: %s", config_arn) 573 for pv in config_info["ProductionVariants"]: 574 model_name = pv["ModelName"] 575 model_arn = _delete_sagemaker_model(model_name, sage_client, s3_client) 576 _logger.info("Deleted associated model with arn: %s", model_arn) 577 578 delete_operation = _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn) 579 580 if synchronous: 581 _logger.info("Waiting for the delete operation to complete...") 582 operation_status = delete_operation.await_completion(timeout_seconds=timeout_seconds) 583 if operation_status.state == _SageMakerOperationStatus.STATE_SUCCEEDED: 584 _logger.info( 585 'The deletion operation completed successfully with message: "%s"', 586 operation_status.message, 587 ) 588 else: 589 raise MlflowException( 590 "The deletion operation failed with the following error message:" 591 f' "{operation_status.message}"' 592 ) 593 if not archive: 594 delete_operation.clean_up() 595 596 597 def deploy_transform_job( 598 job_name, 599 model_uri, 600 s3_input_data_type, 601 s3_input_uri, 602 content_type, 603 s3_output_path, 604 compression_type="None", 605 split_type="Line", 606 accept="text/csv", 607 assemble_with="Line", 608 input_filter="$", 609 output_filter="$", 610 join_resource="None", 611 execution_role_arn=None, 612 assume_role_arn=None, 613 bucket=None, 614 image_url=None, 615 region_name="us-west-2", 616 instance_type=DEFAULT_SAGEMAKER_INSTANCE_TYPE, 617 instance_count=DEFAULT_SAGEMAKER_INSTANCE_COUNT, 618 vpc_config=None, 619 flavor=None, 620 archive=False, 621 synchronous=True, 622 timeout_seconds=1200, 623 ): 624 """ 625 Deploy an MLflow model on AWS SageMaker and create the corresponding batch transform job. 626 The currently active AWS account must have correct permissions set up. 627 628 Args: 629 job_name: Name of the deployed Sagemaker batch transform job. 630 model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker. 631 For example: 632 633 - ``/Users/me/path/to/local/model`` 634 - ``relative/path/to/local/model`` 635 - ``s3://my_bucket/path/to/model`` 636 - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` 637 - ``models:/<model_name>/<model_version>`` 638 - ``models:/<model_name>/<stage>`` 639 640 For more information about supported URI schemes, see 641 `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html# 642 artifact-locations>`_. 643 644 s3_input_data_type: Input data type for the transform job. 645 s3_input_uri: S3 key name prefix or a manifest of the input data. 646 content_type: The multipurpose internet mail extension (MIME) type of the data. 647 s3_output_path: The S3 path to store the output results of the Sagemaker transform job. 648 compression_type: The compression type of the transform data. 649 split_type: The method to split the transform job's data files into smaller batches. 650 accept: The multipurpose internet mail extension (MIME) type of the output data. 651 assemble_with: The method to assemble the results of the transform job as 652 a single S3 object. 653 input_filter: A JSONPath expression used to select a portion of the input data for 654 the transform job. 655 output_filter: A JSONPath expression used to select a portion of the output data from 656 the transform job. 657 join_resource: The source of the data to join with the transformed data. 658 659 execution_role_arn: The name of an IAM role granting the SageMaker service permissions to 660 access the specified Docker image and S3 bucket containing MLflow 661 model artifacts. If unspecified, the currently-assumed role will be 662 used. This execution role is passed to the SageMaker service when 663 creating a SageMaker model from the specified MLflow model. It is 664 passed as the ``ExecutionRoleArn`` parameter of the `SageMaker 665 CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/ 666 dg/API_CreateModel.html>`_. This role is *not* assumed for any other 667 call. For more information about SageMaker execution roles for model 668 creation, see 669 https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html. 670 assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker 671 to another AWS account. If unspecified, SageMaker will be deployed to 672 the the currently active AWS account. 673 bucket: S3 bucket where model artifacts will be stored. Defaults to a 674 SageMaker-compatible bucket name. 675 image_url: URL of the ECR-hosted Docker image the model should be deployed into, produced 676 by ``mlflow sagemaker build-and-push-container``. This parameter can also 677 be specified by the environment variable ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``. 678 region_name: Name of the AWS region to which to deploy the application. 679 instance_type: The type of SageMaker ML instance on which to deploy the model. For a list 680 of supported instance types, see 681 https://aws.amazon.com/sagemaker/pricing/instance-types/. 682 instance_count: The number of SageMaker ML instances on which to deploy the model. 683 vpc_config: A dictionary specifying the VPC configuration to use when creating the 684 new SageMaker model associated with this batch transform job. The acceptable 685 values for this parameter are identical to those of the ``VpcConfig`` 686 parameter in the `SageMaker boto3 client's create_model method 687 <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html 688 #SageMaker.Client.create_model>`_. For more information, see 689 https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html. 690 691 .. code-block:: python 692 :caption: Example 693 694 import mlflow.sagemaker as mfs 695 696 vpc_config = { 697 "SecurityGroupIds": [ 698 "sg-123456abc", 699 ], 700 "Subnets": [ 701 "subnet-123456abc", 702 ], 703 } 704 mfs.deploy_transform_job(..., vpc_config=vpc_config) 705 706 flavor: The name of the flavor of the model to use for deployment. Must be either 707 ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. If ``None``, 708 a flavor is automatically selected from the model's available flavors. If the 709 specified flavor is not present or not supported for deployment, an exception 710 will be thrown. 711 archive: If ``True``, resources like Sagemaker models and model artifacts in S3 are 712 preserved after the finished batch transform job. If ``False``, these resources 713 are deleted. In order to use ``archive=False``, ``deploy_transform_job()`` must 714 be executed synchronously with ``synchronous=True``. 715 synchronous: If ``True``, this function will block until the deployment process succeeds 716 or encounters an irrecoverable failure. If ``False``, this function will 717 return immediately after starting the deployment process. It will not wait 718 for the deployment process to complete; in this case, the caller is 719 responsible for monitoring the health and status of the pending deployment 720 via native SageMaker APIs or the AWS console. 721 timeout_seconds: If ``synchronous`` is ``True``, the deployment process will return after 722 the specified number of seconds if no definitive result (success or 723 failure) is achieved. Once the function returns, the caller is 724 responsible for monitoring the health and status of the pending 725 deployment using native SageMaker APIs or the AWS console. If 726 ``synchronous`` is ``False``, this parameter is ignored. 727 """ 728 import boto3 729 730 if (not archive) and (not synchronous): 731 raise MlflowException( 732 message=( 733 "Resources must be archived when `deploy_transform_job()`" 734 " is executed in non-synchronous mode." 735 " Either set `synchronous=True` or `archive=True`." 736 ), 737 error_code=INVALID_PARAMETER_VALUE, 738 ) 739 740 model_path = _download_artifact_from_uri(model_uri) 741 model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME) 742 if not os.path.exists(model_config_path): 743 raise MlflowException( 744 message=( 745 f"Failed to find {MLMODEL_FILE_NAME} configuration within the specified model's" 746 " root directory." 747 ), 748 error_code=INVALID_PARAMETER_VALUE, 749 ) 750 model_config = Model.load(model_config_path) 751 752 if flavor is None: 753 flavor = _get_preferred_deployment_flavor(model_config) 754 else: 755 _validate_deployment_flavor(model_config, flavor) 756 _logger.info("Using the %s flavor for deployment!", flavor) 757 758 assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn) 759 760 s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials) 761 sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials) 762 763 transform_job_exists = ( 764 _find_transform_job(job_name=job_name, sage_client=sage_client) is not None 765 ) 766 if transform_job_exists: 767 raise MlflowException( 768 message=( 769 f"You are attempting to deploy a batch transform job with name: {job_name}. " 770 "However, a batch transform job with the same name already exists." 771 ), 772 error_code=INVALID_PARAMETER_VALUE, 773 ) 774 775 model_name = _get_sagemaker_transform_model_name(job_name=job_name) 776 if not image_url: 777 image_url = _get_default_image_url(region_name=region_name) 778 if not execution_role_arn: 779 execution_role_arn = _get_assumed_role_arn(**assume_role_credentials) 780 if not bucket: 781 _logger.info("No model data bucket specified, using the default bucket") 782 bucket = _get_default_s3_bucket(region_name, **assume_role_credentials) 783 784 model_s3_path = _upload_s3( 785 local_model_path=model_path, 786 bucket=bucket, 787 prefix=model_name, 788 region_name=region_name, 789 s3_client=s3_client, 790 **assume_role_credentials, 791 ) 792 793 deployment_operation = _create_sagemaker_transform_job( 794 job_name=job_name, 795 model_name=model_name, 796 model_s3_path=model_s3_path, 797 model_uri=model_uri, 798 image_url=image_url, 799 flavor=flavor, 800 vpc_config=vpc_config, 801 role=execution_role_arn, 802 sage_client=sage_client, 803 s3_client=s3_client, 804 instance_type=instance_type, 805 instance_count=instance_count, 806 s3_input_data_type=s3_input_data_type, 807 s3_input_uri=s3_input_uri, 808 content_type=content_type, 809 compression_type=compression_type, 810 split_type=split_type, 811 s3_output_path=s3_output_path, 812 accept=accept, 813 assemble_with=assemble_with, 814 input_filter=input_filter, 815 output_filter=output_filter, 816 join_resource=join_resource, 817 ) 818 819 if synchronous: 820 _logger.info("Waiting for the batch transform job to complete...") 821 operation_status = deployment_operation.await_completion(timeout_seconds=timeout_seconds) 822 if operation_status.state == _SageMakerOperationStatus.STATE_SUCCEEDED: 823 _logger.info( 824 'The batch transform job completed successfully with message: "%s"', 825 operation_status.message, 826 ) 827 else: 828 raise MlflowException( 829 "The batch transform job failed with the following error message:" 830 f' "{operation_status.message}"' 831 ) 832 if not archive: 833 deployment_operation.clean_up() 834 835 836 def terminate_transform_job( 837 job_name, 838 region_name="us-west-2", 839 assume_role_arn=None, 840 archive=False, 841 synchronous=True, 842 timeout_seconds=300, 843 ): 844 """ 845 Terminate a SageMaker batch transform job. 846 847 Args: 848 job_name: Name of the deployed Sagemaker batch transform job. 849 region_name: Name of the AWS region in which the batch transform job is deployed. 850 assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker 851 to another AWS account. If unspecified, SageMaker will be deployed to 852 the the currently active AWS account. 853 archive: If ``True``, resources associated with the specified batch transform job, 854 such as its associated models and model artifacts, are preserved. 855 If ``False``, these resources are deleted. In order to use ``archive=False``, 856 ``terminate_transform_job()`` must be executed synchronously 857 with ``synchronous=True``. 858 synchronous: If `True`, this function blocks until the termination process succeeds 859 or encounters an irrecoverable failure. If `False`, this function 860 returns immediately after starting the termination process. It will not 861 wait for the termination process to complete; in this case, the caller is 862 responsible for monitoring the status of the termination process via native 863 SageMaker APIs or the AWS console. 864 timeout_seconds: If `synchronous` is `True`, the termination process returns after the 865 specified number of seconds if no definitive result (success or failure) 866 is achieved. Once the function returns, the caller is responsible 867 for monitoring the status of the termination process via native 868 SageMaker APIs or the AWS console. If `synchronous` is False, this 869 parameter is ignored. 870 """ 871 import boto3 872 873 if (not archive) and (not synchronous): 874 raise MlflowException( 875 message=( 876 "Resources must be archived when `terminate_transform_job()`" 877 " is executed in non-synchronous mode." 878 " Either set `synchronous=True` or `archive=True`." 879 ), 880 error_code=INVALID_PARAMETER_VALUE, 881 ) 882 883 assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn) 884 885 s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials) 886 sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials) 887 888 transform_job_info = sage_client.describe_transform_job(TransformJobName=job_name) 889 transform_job_arn = transform_job_info["TransformJobArn"] 890 891 sage_client.stop_transform_job(TransformJobName=job_name) 892 _logger.info("Terminated batch transform job with arn: %s", transform_job_arn) 893 894 def status_check_fn(): 895 transform_job_info = _find_transform_job(job_name=job_name, sage_client=sage_client) 896 897 if transform_job_info["TransformJobStatus"] == "Stopping": 898 return _SageMakerOperationStatus.in_progress( 899 "Termination is still in progress. Current batch transform job status: " 900 "{transform_job_status}".format( 901 transform_job_status=transform_job_info["TransformJobStatus"] 902 ) 903 ) 904 elif transform_job_info["TransformJobStatus"] == "Stopped": 905 return _SageMakerOperationStatus.succeeded( 906 "The SageMaker batch transform job was terminated successfully." 907 ) 908 909 def cleanup_fn(): 910 _logger.info("Cleaning up unused resources...") 911 model_name = transform_job_info["ModelName"] 912 model_arn = _delete_sagemaker_model(model_name, sage_client, s3_client) 913 _logger.info("Deleted associated model with arn: %s", model_arn) 914 915 stop_operation = _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn) 916 917 if synchronous: 918 _logger.info("Waiting for the termination operation to complete...") 919 operation_status = stop_operation.await_completion(timeout_seconds=timeout_seconds) 920 if operation_status.state == _SageMakerOperationStatus.STATE_SUCCEEDED: 921 _logger.info( 922 'The termination operation completed successfully with message: "%s"', 923 operation_status.message, 924 ) 925 else: 926 raise MlflowException( 927 "The termination operation failed with the following error message:" 928 f' "{operation_status.message}"' 929 ) 930 if not archive: 931 stop_operation.clean_up() 932 933 934 def push_model_to_sagemaker( 935 model_name, 936 model_uri, 937 execution_role_arn=None, 938 assume_role_arn=None, 939 bucket=None, 940 image_url=None, 941 region_name="us-west-2", 942 vpc_config=None, 943 flavor=None, 944 ): 945 """ 946 Create a SageMaker Model from an MLflow model artifact. 947 The currently active AWS account must have correct permissions set up. 948 949 Args: 950 model_name: Name of the Sagemaker model. 951 model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker. 952 For example: 953 954 - ``/Users/me/path/to/local/model`` 955 - ``relative/path/to/local/model`` 956 - ``s3://my_bucket/path/to/model`` 957 - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` 958 - ``models:/<model_name>/<model_version>`` 959 - ``models:/<model_name>/<stage>`` 960 961 For more information about supported URI schemes, see 962 `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html# 963 artifact-locations>`_. 964 965 execution_role_arn: The name of an IAM role granting the SageMaker service permissions to 966 access the specified Docker image and S3 bucket containing MLflow 967 model artifacts. If unspecified, the currently-assumed role will be 968 used. This execution role is passed to the SageMaker service when 969 creating a SageMaker model from the specified MLflow model. It is 970 passed as the ``ExecutionRoleArn`` parameter of the `SageMaker 971 CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/ 972 dg/API_CreateModel.html>`_. This role is *not* assumed for any other 973 call. For more information about SageMaker execution roles for model 974 creation, see 975 https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html. 976 assume_role_arn: The name of an IAM cross-account role to be assumed to deploy SageMaker 977 to another AWS account. If unspecified, SageMaker will be deployed to 978 the the currently active AWS account. 979 bucket: S3 bucket where model artifacts will be stored. Defaults to a 980 SageMaker-compatible bucket name. 981 image_url: URL of the ECR-hosted Docker image the model should be deployed into, produced 982 by ``mlflow sagemaker build-and-push-container``. This parameter can also 983 be specified by the environment variable ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``. 984 region_name: Name of the AWS region to which to deploy the application. 985 vpc_config: A dictionary specifying the VPC configuration to use when creating the 986 new SageMaker model. The acceptable values for this parameter are identical 987 to those of the ``VpcConfig`` parameter in the `SageMaker boto3 client's 988 create_model method 989 <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html 990 #SageMaker.Client.create_model>`_. For more information, see 991 https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html. 992 993 .. code-block:: python 994 :caption: Example 995 996 import mlflow.sagemaker as mfs 997 998 vpc_config = { 999 "SecurityGroupIds": [ 1000 "sg-123456abc", 1001 ], 1002 "Subnets": [ 1003 "subnet-123456abc", 1004 ], 1005 } 1006 mfs.push_model_to_sagemaker(..., vpc_config=vpc_config) 1007 1008 flavor: The name of the flavor of the model to use for deployment. Must be either 1009 ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. If ``None``, 1010 a flavor is automatically selected from the model's available flavors. If the 1011 specified flavor is not present or not supported for deployment, an exception 1012 will be thrown. 1013 """ 1014 import boto3 1015 1016 model_path = _download_artifact_from_uri(model_uri) 1017 model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME) 1018 if not os.path.exists(model_config_path): 1019 raise MlflowException( 1020 message=( 1021 f"Failed to find {MLMODEL_FILE_NAME} configuration within the specified model's" 1022 " root directory." 1023 ), 1024 error_code=INVALID_PARAMETER_VALUE, 1025 ) 1026 model_config = Model.load(model_config_path) 1027 1028 if flavor is None: 1029 flavor = _get_preferred_deployment_flavor(model_config) 1030 else: 1031 _validate_deployment_flavor(model_config, flavor) 1032 _logger.info("Using the %s flavor for deployment!", flavor) 1033 1034 assume_role_credentials = _assume_role_and_get_credentials(assume_role_arn=assume_role_arn) 1035 1036 s3_client = boto3.client("s3", region_name=region_name, **assume_role_credentials) 1037 sage_client = boto3.client("sagemaker", region_name=region_name, **assume_role_credentials) 1038 1039 if _does_model_exist(model_name=model_name, sage_client=sage_client): 1040 raise MlflowException( 1041 message=( 1042 f"You are attempting to create a Sagemaker model with name: {model_name}. " 1043 "However, a model with the same name already exists." 1044 ), 1045 error_code=INVALID_PARAMETER_VALUE, 1046 ) 1047 1048 if not image_url: 1049 image_url = _get_default_image_url(region_name=region_name) 1050 if not execution_role_arn: 1051 execution_role_arn = _get_assumed_role_arn(**assume_role_credentials) 1052 if not bucket: 1053 _logger.info("No model data bucket specified, using the default bucket") 1054 bucket = _get_default_s3_bucket(region_name, **assume_role_credentials) 1055 1056 model_s3_path = _upload_s3( 1057 local_model_path=model_path, 1058 bucket=bucket, 1059 prefix=model_name, 1060 region_name=region_name, 1061 s3_client=s3_client, 1062 **assume_role_credentials, 1063 ) 1064 1065 model_response = _create_sagemaker_model( 1066 model_name=model_name, 1067 model_s3_path=model_s3_path, 1068 model_uri=model_uri, 1069 flavor=flavor, 1070 vpc_config=vpc_config, 1071 image_url=image_url, 1072 execution_role=execution_role_arn, 1073 sage_client=sage_client, 1074 env={}, 1075 tags={}, 1076 ) 1077 1078 _logger.info("Created Sagemaker model with arn: %s", model_response["ModelArn"]) 1079 1080 1081 def run_local(name, model_uri, flavor=None, config=None): 1082 """ 1083 Serve the model locally in a SageMaker compatible Docker container. 1084 1085 Note that models deployed locally cannot be managed by other deployment APIs 1086 (e.g. ``update_deployment``, ``delete_deployment``, etc). 1087 1088 Args: 1089 name: Name of the local serving application. 1090 model_uri: The location, in URI format, of the MLflow model to deploy locally. 1091 For example: 1092 1093 - ``/Users/me/path/to/local/model`` 1094 - ``relative/path/to/local/model`` 1095 - ``s3://my_bucket/path/to/model`` 1096 - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` 1097 - ``models:/<model_name>/<model_version>`` 1098 - ``models:/<model_name>/<stage>`` 1099 1100 For more information about supported URI schemes, see 1101 `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html# 1102 artifact-locations>`_. 1103 flavor: The name of the flavor of the model to use for deployment. Must be either 1104 ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. 1105 If ``None``, a flavor is automatically selected from the model's available 1106 flavors. If the specified flavor is not present or not supported for 1107 deployment, an exception will be thrown. 1108 config: Configuration parameters. The supported parameters are: 1109 1110 - ``image``: The name of the Docker image to use for model serving. Defaults 1111 to ``"mlflow-pyfunc"``. 1112 - ``port``: The port at which to expose the model server on the local host. 1113 Defaults to ``5000``. 1114 1115 .. code-block:: python 1116 :caption: Python example 1117 1118 from mlflow.models import build_docker 1119 from mlflow.deployments import get_deploy_client 1120 1121 build_docker(name="mlflow-pyfunc") 1122 1123 client = get_deploy_client("sagemaker") 1124 client.run_local( 1125 name="my-local-deployment", 1126 model_uri="/mlruns/0/abc/model", 1127 flavor="python_function", 1128 config={ 1129 "port": 5000, 1130 "image": "mlflow-pyfunc", 1131 }, 1132 ) 1133 1134 .. code-block:: bash 1135 :caption: Command-line example 1136 1137 mlflow models build-docker --name "mlflow-pyfunc" 1138 mlflow deployments run-local --target sagemaker \\ 1139 --name my-local-deployment \\ 1140 --model-uri "/mlruns/0/abc/model" \\ 1141 --flavor python_function \\ 1142 -C port=5000 \\ 1143 -C image="mlflow-pyfunc" 1144 """ 1145 model_path = _download_artifact_from_uri(model_uri) 1146 model_config_path = os.path.join(model_path, MLMODEL_FILE_NAME) 1147 model_config = Model.load(model_config_path) 1148 1149 if flavor is None: 1150 flavor = _get_preferred_deployment_flavor(model_config) 1151 else: 1152 _validate_deployment_flavor(model_config, flavor) 1153 _logger.info("Using the %s flavor for local serving!", flavor) 1154 1155 image = config.get("image", DEFAULT_IMAGE_NAME) 1156 port = int(config.get("port", 5000)) 1157 1158 deployment_config = _get_deployment_config(flavor_name=flavor) 1159 1160 _logger.info("launching docker image with path %s", model_path) 1161 cmd = ["docker", "run", "-v", f"{model_path}:/opt/ml/model/", "-p", f"{port}:8080"] 1162 for key, value in deployment_config.items(): 1163 cmd += ["-e", f"{key}={value}"] 1164 cmd += ["--rm", image, "serve"] 1165 _logger.info("executing: %s", " ".join(cmd)) 1166 proc = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stderr, text=True) 1167 1168 def _sigterm_handler(*_): 1169 _logger.info("received termination signal => killing docker process") 1170 proc.send_signal(signal.SIGINT) 1171 1172 signal.signal(signal.SIGTERM, _sigterm_handler) 1173 proc.wait() 1174 1175 1176 def target_help(): 1177 """ 1178 Provide help information for the SageMaker deployment client. 1179 """ 1180 return """\ 1181 For detailed documentation on the SageMaker deployment client, please visit 1182 https://mlflow.org/docs/latest/python_api/mlflow.sagemaker.html#mlflow.sagemaker.SageMakerDeploymentClient 1183 1184 The target URI must follow the following formats: 1185 - sagemaker 1186 - sagemaker:/region_name 1187 - sagemaker:/region_name/assume_role_arn 1188 1189 When the region_name or assume_role_arn are provided, they will be used as the default region 1190 and assumed role ARN when executing the commands. 1191 1192 The `create` and `update` commands require a deployment name and a model_uri. The model flavor 1193 and deployment configuration can be optionally provided. These commands can also be executed 1194 in synchronous or asynchronous mode. 1195 1196 The `delete` command accepts configurations to archive a model instead of deleting, execute 1197 in asynchronous mode and timeout period. 1198 """ 1199 1200 1201 def _get_default_image_url(region_name): 1202 import boto3 1203 1204 if env_img := MLFLOW_SAGEMAKER_DEPLOY_IMG_URL.get(): 1205 return env_img 1206 1207 ecr_client = boto3.client("ecr", region_name=region_name) 1208 repository_conf = ecr_client.describe_repositories(repositoryNames=[DEFAULT_IMAGE_NAME])[ 1209 "repositories" 1210 ][0] 1211 return (repository_conf["repositoryUri"] + ":{version}").format(version=mlflow.version.VERSION) 1212 1213 1214 def _get_account_id(**assume_role_credentials): 1215 import boto3 1216 1217 sess = boto3.Session() 1218 sts_client = sess.client("sts", **assume_role_credentials) 1219 identity_info = sts_client.get_caller_identity() 1220 return identity_info["Account"] 1221 1222 1223 def _get_assumed_role_arn(**assume_role_credentials): 1224 """ 1225 Returns: 1226 ARN of the user's current IAM role. 1227 """ 1228 import boto3 1229 1230 sess = boto3.Session() 1231 sts_client = sess.client("sts", **assume_role_credentials) 1232 identity_info = sts_client.get_caller_identity() 1233 sts_arn = identity_info["Arn"] 1234 role_name = sts_arn.split("/")[1] 1235 iam_client = sess.client("iam", **assume_role_credentials) 1236 role_response = iam_client.get_role(RoleName=role_name) 1237 return role_response["Role"]["Arn"] 1238 1239 1240 def _assume_role_and_get_credentials(assume_role_arn=None): 1241 """ 1242 Assume a new role in AWS and return the credentials for that role. 1243 When ``assume_role_arn`` is ``None`` or an empty string, 1244 this function does nothing and returns an empty dictionary. 1245 1246 Args: 1247 assume_role_arn: Optional ARN of the role that will be assumed 1248 1249 Returns: 1250 Dict with credentials of the assumed role 1251 """ 1252 import boto3 1253 1254 if not assume_role_arn: 1255 return {} 1256 1257 sts_client = boto3.client("sts") 1258 sts_response = sts_client.assume_role( 1259 RoleArn=assume_role_arn, RoleSessionName="mlflow-sagemaker" 1260 ) 1261 1262 _logger.info("Assuming role %s for deployment!", assume_role_arn) 1263 1264 return { 1265 "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"], 1266 "aws_secret_access_key": sts_response["Credentials"]["SecretAccessKey"], 1267 "aws_session_token": sts_response["Credentials"]["SessionToken"], 1268 } 1269 1270 1271 def _get_default_s3_bucket(region_name, **assume_role_credentials): 1272 import boto3 1273 1274 # create bucket if it does not exist 1275 sess = boto3.Session() 1276 account_id = _get_account_id(**assume_role_credentials) 1277 bucket_name = f"{DEFAULT_BUCKET_NAME_PREFIX}-{region_name}-{account_id}" 1278 s3 = sess.client("s3", **assume_role_credentials) 1279 response = s3.list_buckets() 1280 buckets = [b["Name"] for b in response["Buckets"]] 1281 if bucket_name not in buckets: 1282 _logger.info("Default bucket `%s` not found. Creating...", bucket_name) 1283 bucket_creation_kwargs = { 1284 "ACL": "bucket-owner-full-control", 1285 "Bucket": bucket_name, 1286 } 1287 if region_name != "us-east-1": 1288 # The location constraint is required during bucket creation for all regions 1289 # outside of us-east-1. This constraint cannot be specified in us-east-1; 1290 # specifying it in this region results in a failure, so we will only 1291 # add it if we are deploying outside of us-east-1. 1292 # See https://docs.aws.amazon.com/cli/latest/reference/s3api/create-bucket.html#examples 1293 bucket_creation_kwargs["CreateBucketConfiguration"] = { 1294 "LocationConstraint": region_name 1295 } 1296 response = s3.create_bucket(**bucket_creation_kwargs) 1297 _logger.info("Bucket creation response: %s", response) 1298 else: 1299 _logger.info("Default bucket `%s` already exists. Skipping creation.", bucket_name) 1300 return bucket_name 1301 1302 1303 def _make_tarfile(output_filename, source_dir): 1304 """ 1305 create a tar.gz from a directory. 1306 """ 1307 with tarfile.open(output_filename, "w:gz") as tar: 1308 for f in os.listdir(source_dir): 1309 tar.add(os.path.join(source_dir, f), arcname=f) 1310 1311 1312 def _upload_s3(local_model_path, bucket, prefix, region_name, s3_client, **assume_role_credentials): 1313 """ 1314 Upload dir to S3 as .tar.gz. 1315 1316 Args: 1317 local_model_path: Local path to a dir. 1318 bucket: S3 bucket where to store the data. 1319 prefix: Path within the bucket. 1320 region_name: The AWS region in which to upload data to S3. 1321 s3_client: A boto3 client for S3. 1322 1323 Returns: 1324 S3 path of the uploaded artifact. 1325 """ 1326 import boto3 1327 1328 sess = boto3.Session(region_name=region_name, **assume_role_credentials) 1329 with TempDir() as tmp: 1330 model_data_file = tmp.path("model.tar.gz") 1331 _make_tarfile(model_data_file, local_model_path) 1332 with open(model_data_file, "rb") as fobj: 1333 key = os.path.join(prefix, "model.tar.gz") 1334 obj = sess.resource("s3").Bucket(bucket).Object(key) 1335 obj.upload_fileobj(fobj) 1336 response = s3_client.put_object_tagging( 1337 Bucket=bucket, Key=key, Tagging={"TagSet": [{"Key": "SageMaker", "Value": "true"}]} 1338 ) 1339 _logger.info("tag response: %s", response) 1340 return f"s3://{bucket}/{key}" 1341 1342 1343 def _get_deployment_config(flavor_name, env_override=None): 1344 """ 1345 Returns: 1346 The deployment configuration as a dictionary 1347 """ 1348 deployment_config = { 1349 MLFLOW_DEPLOYMENT_FLAVOR_NAME.name: flavor_name, 1350 SERVING_ENVIRONMENT: SAGEMAKER_SERVING_ENVIRONMENT, 1351 } 1352 if env_override: 1353 deployment_config.update(env_override) 1354 1355 if os.environ.get("http_proxy") is not None: 1356 deployment_config.update({"http_proxy": os.environ["http_proxy"]}) 1357 1358 if os.environ.get("https_proxy") is not None: 1359 deployment_config.update({"https_proxy": os.environ["https_proxy"]}) 1360 1361 if os.environ.get("no_proxy") is not None: 1362 deployment_config.update({"no_proxy": os.environ["no_proxy"]}) 1363 1364 return deployment_config 1365 1366 1367 def _truncate_name(name, max_length): 1368 # NB: Sagemaker prevents the registration of models and configurations whose names 1369 # exceed 63 characters in length. For reference: 1370 # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_Model.html 1371 # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_TransformJob.html 1372 # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ModelConfiguration.html 1373 # This function middle-truncates the name provided to 1374 # ensure that the least critical name information is not lost 1375 if len(name) <= max_length: 1376 return name 1377 available_length = max_length - 3 1378 start_len = available_length // 2 1379 end_len = available_length - start_len 1380 truncated_name = f"{name[:start_len]}---{name[-end_len:]}" 1381 _logger.warning( 1382 f"Truncated name {name} to {truncated_name} to coerce total character counts to < 64" 1383 ) 1384 return truncated_name 1385 1386 1387 def _get_unique_name(base_name, unique_suffix, unique_id_length=20): 1388 unique_id = uuid.uuid4().hex[:unique_id_length] 1389 unique_resource_string = f"{unique_suffix}{unique_id}" 1390 max_length = 63 - len(unique_resource_string) 1391 return _truncate_name(base_name, max_length) + unique_resource_string 1392 1393 1394 def _get_sagemaker_model_name(endpoint_name): 1395 return _get_unique_name(endpoint_name, "-model-") 1396 1397 1398 def _get_sagemaker_transform_model_name(job_name): 1399 return _get_unique_name(job_name, "-model-") 1400 1401 1402 def _get_sagemaker_config_name(endpoint_name): 1403 return _get_unique_name(endpoint_name, "-config-") 1404 1405 1406 def _get_sagemaker_config_tags(endpoint_name): 1407 return [{"Key": SAGEMAKER_APP_NAME_TAG_KEY, "Value": endpoint_name}] 1408 1409 1410 def _prepare_sagemaker_tags( 1411 config_tags: list[dict[str, str]], 1412 sagemaker_tags: dict[str, str] | None = None, 1413 ): 1414 if not sagemaker_tags: 1415 return config_tags 1416 1417 if SAGEMAKER_APP_NAME_TAG_KEY in sagemaker_tags: 1418 raise MlflowException.invalid_parameter_value( 1419 f"Duplicate tag provided for '{SAGEMAKER_APP_NAME_TAG_KEY}'" 1420 ) 1421 parsed = [{"Key": key, "Value": str(value)} for key, value in sagemaker_tags.items()] 1422 1423 return config_tags + parsed 1424 1425 1426 def _create_sagemaker_transform_job( 1427 job_name, 1428 model_name, 1429 model_s3_path, 1430 model_uri, 1431 image_url, 1432 flavor, 1433 vpc_config, 1434 role, 1435 sage_client, 1436 s3_client, 1437 instance_type, 1438 instance_count, 1439 s3_input_data_type, 1440 s3_input_uri, 1441 content_type, 1442 compression_type, 1443 split_type, 1444 s3_output_path, 1445 accept, 1446 assemble_with, 1447 input_filter, 1448 output_filter, 1449 join_resource, 1450 ): 1451 """ 1452 Args: 1453 job_name: Name of the deployed Sagemaker batch transform job. 1454 model_name: The name to assign the new SageMaker model that will be associated with the 1455 specified batch transform job. 1456 model_s3_path: S3 path where we stored the model artifacts. 1457 model_uri: URI of the MLflow model to associate with the specified SageMaker batch 1458 transform job. 1459 image_url: URL of the ECR-hosted docker image the model is being deployed into. 1460 flavor: The name of the flavor of the model to use for deployment. 1461 vpc_config: A dictionary specifying the VPC configuration to use when creating the 1462 new SageMaker model associated with this SageMaker batch transform job. 1463 role: SageMaker execution ARN role. 1464 sage_client: A boto3 client for SageMaker. 1465 s3_client: A boto3 client for S3. 1466 instance_type: The type of SageMaker ML instance on which to deploy the model. 1467 instance_count: The number of SageMaker ML instances on which to deploy the model. 1468 s3_input_data_type: Input data type for the transform job. 1469 s3_input_uri: S3 key name prefix or a manifest of the input data. 1470 content_type: The multipurpose internet mail extension (MIME) type of the data. 1471 compression_type: The compression type of the transform data. 1472 split_type: The method to split the transform job's data files into smaller batches. 1473 s3_output_path: The S3 path to store the output results of the Sagemaker transform job. 1474 accept: The multipurpose internet mail extension (MIME) type of the output data. 1475 assemble_with: The method to assemble the results of the transform job as a single 1476 S3 object. 1477 input_filter: A JSONPath expression used to select a portion of the input data for the 1478 transform job. 1479 output_filter: A JSONPath expression used to select a portion of the output data from 1480 the transform job. 1481 join_resource: The source of the data to join with the transformed data. 1482 """ 1483 _logger.info("Creating new batch transform job with name: %s ...", job_name) 1484 1485 model_response = _create_sagemaker_model( 1486 model_name=model_name, 1487 model_s3_path=model_s3_path, 1488 model_uri=model_uri, 1489 flavor=flavor, 1490 vpc_config=vpc_config, 1491 image_url=image_url, 1492 execution_role=role, 1493 sage_client=sage_client, 1494 env={}, 1495 tags={}, 1496 ) 1497 _logger.info("Created model with arn: %s", model_response["ModelArn"]) 1498 1499 transform_input = { 1500 "DataSource": {"S3DataSource": {"S3DataType": s3_input_data_type, "S3Uri": s3_input_uri}}, 1501 "ContentType": content_type, 1502 "CompressionType": compression_type, 1503 "SplitType": split_type, 1504 } 1505 1506 transform_output = { 1507 "S3OutputPath": s3_output_path, 1508 "Accept": accept, 1509 "AssembleWith": assemble_with, 1510 } 1511 1512 transform_resources = {"InstanceType": instance_type, "InstanceCount": instance_count} 1513 1514 data_processing = { 1515 "InputFilter": input_filter, 1516 "OutputFilter": output_filter, 1517 "JoinSource": join_resource, 1518 } 1519 1520 transform_job_response = sage_client.create_transform_job( 1521 TransformJobName=job_name, 1522 ModelName=model_name, 1523 TransformInput=transform_input, 1524 TransformOutput=transform_output, 1525 TransformResources=transform_resources, 1526 DataProcessing=data_processing, 1527 Tags=[{"Key": "model_name", "Value": model_name}], 1528 ) 1529 _logger.info( 1530 "Created batch transform job with arn: %s", transform_job_response["TransformJobArn"] 1531 ) 1532 1533 def status_check_fn(): 1534 transform_job_info = sage_client.describe_transform_job(TransformJobName=job_name) 1535 1536 if transform_job_info is None: 1537 return _SageMakerOperationStatus.in_progress( 1538 "Waiting for batch transform job to be created..." 1539 ) 1540 1541 transform_job_status = transform_job_info["TransformJobStatus"] 1542 if transform_job_status == "InProgress": 1543 return _SageMakerOperationStatus.in_progress( 1544 'Waiting for batch transform job to reach the "Completed" state. ' 1545 f' Current batch transform job status: "{transform_job_status}"' 1546 ) 1547 elif transform_job_status == "Completed": 1548 return _SageMakerOperationStatus.succeeded( 1549 "The SageMaker batch transform job was processed successfully." 1550 ) 1551 else: 1552 failure_reason = transform_job_info.get( 1553 "FailureReason", 1554 "An unknown SageMaker failure occurred. Please see the SageMaker console logs" 1555 " for more information.", 1556 ) 1557 return _SageMakerOperationStatus.failed(failure_reason) 1558 1559 def cleanup_fn(): 1560 _logger.info("Cleaning up Sagemaker model and S3 model artifacts...") 1561 transform_job_info = sage_client.describe_transform_job(TransformJobName=job_name) 1562 model_name = transform_job_info["ModelName"] 1563 model_arn = _delete_sagemaker_model(model_name, sage_client, s3_client) 1564 _logger.info("Deleted associated model with arn: %s", model_arn) 1565 1566 return _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn) 1567 1568 1569 def _create_sagemaker_endpoint( 1570 endpoint_name, 1571 model_name, 1572 model_s3_path, 1573 model_uri, 1574 image_url, 1575 flavor, 1576 instance_type, 1577 vpc_config, 1578 data_capture_config, 1579 instance_count, 1580 role, 1581 sage_client, 1582 variant_name=None, 1583 async_inference_config=None, 1584 serverless_config=None, 1585 env=None, 1586 tags=None, 1587 ): 1588 """ 1589 Args: 1590 endpoint_name: The name of the SageMaker endpoint to create. 1591 model_name: The name to assign the new SageMaker model that will be associated with the 1592 specified endpoint. 1593 model_s3_path: S3 path where we stored the model artifacts. 1594 model_uri: URI of the MLflow model to associate with the specified SageMaker endpoint. 1595 image_url: URL of the ECR-hosted docker image the model is being deployed into. 1596 flavor: The name of the flavor of the model to use for deployment. 1597 instance_type: The type of SageMaker ML instance on which to deploy the model. 1598 instance_count: The number of SageMaker ML instances on which to deploy the model. 1599 vpc_config: A dictionary specifying the VPC configuration to use when creating the 1600 new SageMaker model associated with this SageMaker endpoint. 1601 data_capture_config: A dictionary specifying the data capture configuration to use when 1602 creating the new SageMaker model associated with this application. 1603 role: SageMaker execution ARN role. 1604 sage_client: A boto3 client for SageMaker. 1605 variant_name: The name to assign to the new production variant. 1606 env: A dictionary of environment variables to set for the model. 1607 tags: A dictionary of tags to apply to the endpoint. 1608 """ 1609 _logger.info("Creating new endpoint with name: %s ...", endpoint_name) 1610 1611 model_response = _create_sagemaker_model( 1612 model_name=model_name, 1613 model_s3_path=model_s3_path, 1614 model_uri=model_uri, 1615 flavor=flavor, 1616 vpc_config=vpc_config, 1617 image_url=image_url, 1618 execution_role=role, 1619 sage_client=sage_client, 1620 env=env or {}, 1621 tags=tags or {}, 1622 ) 1623 _logger.info("Created model with arn: %s", model_response["ModelArn"]) 1624 1625 if not variant_name: 1626 variant_name = model_name 1627 1628 production_variant = { 1629 "VariantName": variant_name, 1630 "ModelName": model_name, 1631 "InitialVariantWeight": 1, 1632 } 1633 1634 if serverless_config: 1635 production_variant["ServerlessConfig"] = serverless_config 1636 else: 1637 production_variant["InstanceType"] = instance_type 1638 production_variant["InitialInstanceCount"] = instance_count 1639 1640 config_name = _get_sagemaker_config_name(endpoint_name) 1641 config_tags = _get_sagemaker_config_tags(endpoint_name) 1642 tags_list = _prepare_sagemaker_tags(config_tags, tags) 1643 endpoint_config_kwargs = { 1644 "EndpointConfigName": config_name, 1645 "ProductionVariants": [production_variant], 1646 "Tags": config_tags, 1647 } 1648 if async_inference_config: 1649 endpoint_config_kwargs["AsyncInferenceConfig"] = async_inference_config 1650 if data_capture_config is not None: 1651 endpoint_config_kwargs["DataCaptureConfig"] = data_capture_config 1652 endpoint_config_response = sage_client.create_endpoint_config(**endpoint_config_kwargs) 1653 _logger.info( 1654 "Created endpoint configuration with arn: %s", endpoint_config_response["EndpointConfigArn"] 1655 ) 1656 1657 endpoint_response = sage_client.create_endpoint( 1658 EndpointName=endpoint_name, 1659 EndpointConfigName=config_name, 1660 Tags=tags_list or [], 1661 ) 1662 _logger.info("Created endpoint with arn: %s", endpoint_response["EndpointArn"]) 1663 1664 def status_check_fn(): 1665 endpoint_info = _find_endpoint(endpoint_name=endpoint_name, sage_client=sage_client) 1666 1667 if endpoint_info is None: 1668 return _SageMakerOperationStatus.in_progress("Waiting for endpoint to be created...") 1669 1670 endpoint_status = endpoint_info["EndpointStatus"] 1671 if endpoint_status == "Creating": 1672 return _SageMakerOperationStatus.in_progress( 1673 'Waiting for endpoint to reach the "InService" state. Current endpoint status:' 1674 f' "{endpoint_status}"' 1675 ) 1676 elif endpoint_status == "InService": 1677 return _SageMakerOperationStatus.succeeded( 1678 "The SageMaker endpoint was created successfully." 1679 ) 1680 else: 1681 failure_reason = endpoint_info.get( 1682 "FailureReason", 1683 "An unknown SageMaker failure occurred. Please see the SageMaker console logs" 1684 " for more information.", 1685 ) 1686 return _SageMakerOperationStatus.failed(failure_reason) 1687 1688 def cleanup_fn(): 1689 pass 1690 1691 return _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn) 1692 1693 1694 def _update_sagemaker_endpoint( 1695 endpoint_name, 1696 model_name, 1697 model_uri, 1698 image_url, 1699 model_s3_path, 1700 flavor, 1701 instance_type, 1702 instance_count, 1703 vpc_config, 1704 mode, 1705 role, 1706 sage_client, 1707 s3_client, 1708 variant_name=None, 1709 async_inference_config=None, 1710 serverless_config=None, 1711 data_capture_config=None, 1712 env=None, 1713 tags=None, 1714 ): 1715 """ 1716 Args: 1717 endpoint_name: The name of the SageMaker endpoint to update. 1718 model_name: The name to assign the new SageMaker model that will be associated with the 1719 specified endpoint. 1720 model_uri: URI of the MLflow model to associate with the specified SageMaker endpoint. 1721 image_url: URL of the ECR-hosted Docker image the model is being deployed into 1722 model_s3_path: S3 path where we stored the model artifacts 1723 flavor: The name of the flavor of the model to use for deployment. 1724 instance_type: The type of SageMaker ML instance on which to deploy the model. 1725 instance_count: The number of SageMaker ML instances on which to deploy the model. 1726 vpc_config: A dictionary specifying the VPC configuration to use when creating the 1727 new SageMaker model associated with this SageMaker endpoint. 1728 mode: either mlflow.sagemaker.DEPLOYMENT_MODE_ADD or 1729 mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE. 1730 role: SageMaker execution ARN role. 1731 sage_client: A boto3 client for SageMaker. 1732 s3_client: A boto3 client for S3. 1733 variant_name: The name to assign to the new production variant if it doesn't already exist. 1734 async_inference_config: A dictionary specifying the async inference configuration to use. 1735 For more information, see https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AsyncInferenceConfig.html. 1736 Defaults to ``None``. 1737 data_capture_config: A dictionary specifying the data capture configuration to use. 1738 For more information, see https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html. 1739 Defaults to ``None``. 1740 env: A dictionary of environment variables to set for the model. 1741 tags: A dictionary of tags to apply to the endpoint configuration. 1742 """ 1743 if mode not in [DEPLOYMENT_MODE_ADD, DEPLOYMENT_MODE_REPLACE]: 1744 msg = f"Invalid mode `{mode}` for deployment to a pre-existing application" 1745 raise ValueError(msg) 1746 1747 endpoint_info = sage_client.describe_endpoint(EndpointName=endpoint_name) 1748 endpoint_arn = endpoint_info["EndpointArn"] 1749 deployed_config_name = endpoint_info["EndpointConfigName"] 1750 deployed_config_info = sage_client.describe_endpoint_config( 1751 EndpointConfigName=deployed_config_name 1752 ) 1753 deployed_config_arn = deployed_config_info["EndpointConfigArn"] 1754 deployed_production_variants = deployed_config_info["ProductionVariants"] 1755 1756 _logger.info("Found active endpoint with arn: %s. Updating...", endpoint_arn) 1757 1758 new_model_response = _create_sagemaker_model( 1759 model_name=model_name, 1760 model_s3_path=model_s3_path, 1761 model_uri=model_uri, 1762 flavor=flavor, 1763 vpc_config=vpc_config, 1764 image_url=image_url, 1765 execution_role=role, 1766 sage_client=sage_client, 1767 env=env or {}, 1768 tags=tags or {}, 1769 ) 1770 _logger.info("Created new model with arn: %s", new_model_response["ModelArn"]) 1771 1772 if not variant_name: 1773 variant_name = model_name 1774 1775 if mode == DEPLOYMENT_MODE_ADD: 1776 new_model_weight = 0 1777 production_variants = deployed_production_variants 1778 elif mode == DEPLOYMENT_MODE_REPLACE: 1779 new_model_weight = 1 1780 production_variants = [] 1781 1782 new_production_variant = { 1783 "VariantName": variant_name, 1784 "ModelName": model_name, 1785 "InitialVariantWeight": new_model_weight, 1786 } 1787 1788 if serverless_config: 1789 new_production_variant["ServerlessConfig"] = serverless_config 1790 else: 1791 new_production_variant["InstanceType"] = instance_type 1792 new_production_variant["InitialInstanceCount"] = instance_count 1793 1794 production_variants.append(new_production_variant) 1795 1796 # Create the new endpoint configuration and update the endpoint 1797 # to adopt the new configuration 1798 new_config_name = _get_sagemaker_config_name(endpoint_name) 1799 config_tags = _get_sagemaker_config_tags(endpoint_name) 1800 endpoint_config_kwargs = { 1801 "EndpointConfigName": new_config_name, 1802 "ProductionVariants": production_variants, 1803 "Tags": config_tags, 1804 } 1805 if async_inference_config: 1806 endpoint_config_kwargs["AsyncInferenceConfig"] = async_inference_config 1807 if data_capture_config is not None: 1808 endpoint_config_kwargs["DataCaptureConfig"] = data_capture_config 1809 endpoint_config_response = sage_client.create_endpoint_config(**endpoint_config_kwargs) 1810 _logger.info( 1811 "Created new endpoint configuration with arn: %s", 1812 endpoint_config_response["EndpointConfigArn"], 1813 ) 1814 1815 sage_client.update_endpoint(EndpointName=endpoint_name, EndpointConfigName=new_config_name) 1816 _logger.info("Updated endpoint with new configuration!") 1817 1818 operation_start_time = time.time() 1819 1820 def status_check_fn(): 1821 if time.time() - operation_start_time < 20: 1822 # Wait at least 20 seconds before checking the status of the update; this ensures 1823 # that we don't consider the operation to have failed if small delays occur at 1824 # initialization time 1825 return _SageMakerOperationStatus.in_progress() 1826 1827 endpoint_info = sage_client.describe_endpoint(EndpointName=endpoint_name) 1828 endpoint_update_was_rolled_back = ( 1829 endpoint_info["EndpointStatus"] == "InService" 1830 and endpoint_info["EndpointConfigName"] != new_config_name 1831 ) 1832 if endpoint_update_was_rolled_back or endpoint_info["EndpointStatus"] == "Failed": 1833 failure_reason = endpoint_info.get( 1834 "FailureReason", 1835 "An unknown SageMaker failure occurred." 1836 " Please see the SageMaker console logs for" 1837 " more information.", 1838 ) 1839 return _SageMakerOperationStatus.failed(failure_reason) 1840 elif endpoint_info["EndpointStatus"] == "InService": 1841 return _SageMakerOperationStatus.succeeded( 1842 "The SageMaker endpoint was updated successfully." 1843 ) 1844 else: 1845 return _SageMakerOperationStatus.in_progress( 1846 "The update operation is still in progress. Current endpoint status:" 1847 ' "{endpoint_status}"'.format(endpoint_status=endpoint_info["EndpointStatus"]) 1848 ) 1849 1850 def cleanup_fn(): 1851 _logger.info("Cleaning up unused resources...") 1852 if mode == DEPLOYMENT_MODE_REPLACE: 1853 for pv in deployed_production_variants: 1854 deployed_model_arn = _delete_sagemaker_model( 1855 model_name=pv["ModelName"], sage_client=sage_client, s3_client=s3_client 1856 ) 1857 _logger.info("Deleted model with arn: %s", deployed_model_arn) 1858 1859 sage_client.delete_endpoint_config(EndpointConfigName=deployed_config_name) 1860 _logger.info("Deleted endpoint configuration with arn: %s", deployed_config_arn) 1861 1862 return _SageMakerOperation(status_check_fn=status_check_fn, cleanup_fn=cleanup_fn) 1863 1864 1865 def _create_sagemaker_model( 1866 model_name, 1867 model_s3_path, 1868 model_uri, 1869 flavor, 1870 vpc_config, 1871 image_url, 1872 execution_role, 1873 sage_client, 1874 env, 1875 tags, 1876 ): 1877 """ 1878 Args: 1879 model_name: The name to assign the new SageMaker model that is created. 1880 model_s3_path: S3 path where the model artifacts are stored. 1881 model_uri: URI of the MLflow model associated with the new SageMaker model. 1882 flavor: The name of the flavor of the model. 1883 vpc_config: A dictionary specifying the VPC configuration to use when creating the 1884 new SageMaker model associated with this SageMaker endpoint. 1885 image_url: URL of the ECR-hosted Docker image that will serve as the 1886 model's container, 1887 execution_role: The ARN of the role that SageMaker will assume when creating the model. 1888 sage_client: A boto3 client for SageMaker. 1889 env: A dictionary of environment variables to set for the model. 1890 tags: A dictionary of tags to apply to the SageMaker model. 1891 1892 Returns: 1893 AWS response containing metadata associated with the new model. 1894 """ 1895 tags["model_uri"] = str(model_uri) 1896 create_model_args = { 1897 "ModelName": model_name, 1898 "PrimaryContainer": { 1899 "Image": image_url, 1900 "ModelDataUrl": model_s3_path, 1901 "Environment": _get_deployment_config(flavor_name=flavor, env_override=env), 1902 }, 1903 "ExecutionRoleArn": execution_role, 1904 "Tags": [{"Key": key, "Value": str(value)} for key, value in tags.items()], 1905 } 1906 if vpc_config is not None: 1907 create_model_args["VpcConfig"] = vpc_config 1908 1909 return sage_client.create_model(**create_model_args) 1910 1911 1912 def _delete_sagemaker_model(model_name, sage_client, s3_client): 1913 """ 1914 Args: 1915 sage_client: A boto3 client for SageMaker. 1916 s3_client: A boto3 client for S3. 1917 1918 Returns: 1919 ARN of the deleted model. 1920 """ 1921 model_info = sage_client.describe_model(ModelName=model_name) 1922 model_arn = model_info["ModelArn"] 1923 model_data_url = model_info["PrimaryContainer"]["ModelDataUrl"] 1924 1925 # Parse the model data url to obtain a bucket path. The following 1926 # procedure is safe due to the well-documented structure of the `ModelDataUrl` 1927 # (see https://docs.aws.amazon.com/sagemaker/latest/dg/API_ContainerDefinition.html) 1928 parsed_data_url = urllib.parse.urlparse(model_data_url) 1929 bucket_name = parsed_data_url.netloc 1930 bucket_key = parsed_data_url.path.lstrip("/") 1931 1932 s3_client.delete_object(Bucket=bucket_name, Key=bucket_key) 1933 sage_client.delete_model(ModelName=model_name) 1934 1935 return model_arn 1936 1937 1938 def _delete_sagemaker_endpoint_configuration(endpoint_config_name, sage_client): 1939 """ 1940 Args: 1941 sage_client: A boto3 client for SageMaker. 1942 1943 Returns: 1944 ARN of the deleted endpoint configuration. 1945 """ 1946 endpoint_config_info = sage_client.describe_endpoint_config( 1947 EndpointConfigName=endpoint_config_name 1948 ) 1949 sage_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) 1950 return endpoint_config_info["EndpointConfigArn"] 1951 1952 1953 def _find_endpoint(endpoint_name, sage_client): 1954 """ 1955 Finds a SageMaker endpoint with the specified name in the caller's AWS account, returning a 1956 NoneType if the endpoint is not found. 1957 1958 Args: 1959 sage_client: A boto3 client for SageMaker. 1960 1961 Returns: 1962 If the endpoint exists, a dictionary of endpoint attributes. If the endpoint does not 1963 exist, ``None``. 1964 """ 1965 endpoints_page = sage_client.list_endpoints(MaxResults=100, NameContains=endpoint_name) 1966 1967 while True: 1968 for endpoint in endpoints_page["Endpoints"]: 1969 if endpoint["EndpointName"] == endpoint_name: 1970 return endpoint 1971 1972 if "NextToken" in endpoints_page: 1973 endpoints_page = sage_client.list_endpoints( 1974 MaxResults=100, NextToken=endpoints_page["NextToken"], NameContains=endpoint_name 1975 ) 1976 else: 1977 return None 1978 1979 1980 def _find_transform_job(job_name, sage_client): 1981 """ 1982 Finds a SageMaker batch transform job with the specified name in the caller's AWS account, 1983 returning a NoneType if the transform job is not found. 1984 1985 Args: 1986 sage_client: A boto3 client for SageMaker. 1987 1988 Returns: 1989 If the transform job exists, a dictionary of transform job attributes. If the 1990 transform job does not exist, ``None``. 1991 """ 1992 transform_jobs_page = sage_client.list_transform_jobs(MaxResults=100, NameContains=job_name) 1993 1994 while True: 1995 for transform_job in transform_jobs_page["TransformJobSummaries"]: 1996 if transform_job["TransformJobName"] == job_name: 1997 return transform_job 1998 1999 if "NextToken" in transform_jobs_page: 2000 transform_jobs_page = sage_client.list_transform_jobs( 2001 MaxResults=100, 2002 NextToken=transform_jobs_page["NextToken"], 2003 NameContains=job_name, 2004 ) 2005 else: 2006 return None 2007 2008 2009 def _does_model_exist(model_name, sage_client): 2010 """ 2011 Determines whether a SageMaker model exists with the specified name in the caller's AWS account, 2012 returning True if the model exists, returning False if the model does not exist. 2013 2014 Args: 2015 sage_client: A boto3 client for SageMaker. 2016 2017 Returns: 2018 If the model exists, ``True``. If the model does not 2019 exist, ``False``. 2020 """ 2021 try: 2022 response = sage_client.describe_model(ModelName=model_name) 2023 except sage_client.exceptions.ClientError as error: 2024 if "Could not find model" in error.response["Error"]["Message"]: 2025 return False 2026 else: 2027 return bool(response) 2028 2029 2030 class SageMakerDeploymentClient(BaseDeploymentClient): 2031 """ 2032 Initialize a deployment client for SageMaker. The default region and assumed role ARN will 2033 be set according to the value of the `target_uri`. 2034 2035 This class is meant to supersede the other ``mlflow.sagemaker`` real-time serving API's. 2036 It is also designed to be used through the :py:mod:`mlflow.deployments` module. 2037 This means that you can deploy to SageMaker using the 2038 `mlflow deployments CLI <https://www.mlflow.org/docs/latest/cli.html#mlflow-deployments>`_ and 2039 get a client through the :py:mod:`mlflow.deployments.get_deploy_client` function. 2040 2041 Args: 2042 target_uri: A URI that follows one of the following formats: 2043 2044 - ``sagemaker``: This will set the default region to `us-west-2` and 2045 the default assumed role ARN to `None`. 2046 - ``sagemaker:/region_name``: This will set the default region to 2047 `region_name` and the default assumed role ARN to `None`. 2048 - ``sagemaker:/region_name/assumed_role_arn``: This will set the default 2049 region to `region_name` and the default assumed role ARN to 2050 `assumed_role_arn`. 2051 2052 When an `assumed_role_arn` is provided without a `region_name`, 2053 an MlflowException will be raised. 2054 """ 2055 2056 def __init__(self, target_uri): 2057 super().__init__(target_uri=target_uri) 2058 2059 # Default region_name and assumed_role_arn when 2060 # the target_uri is `sagemaker` or `sagemaker:/` 2061 self.region_name = DEFAULT_REGION_NAME 2062 self.assumed_role_arn = None 2063 self._get_values_from_target_uri() 2064 2065 def _get_values_from_target_uri(self): 2066 parsed = urllib.parse.urlparse(self.target_uri) 2067 values_str = parsed.path.strip("/") 2068 2069 if not parsed.scheme or not values_str: 2070 return 2071 2072 separator_index = values_str.find("/") 2073 if separator_index == -1: 2074 # values_str would look like us-east-1 2075 self.region_name = values_str 2076 else: 2077 # values_str could look like us-east-1/arn:aws:1234:role/assumed_role 2078 self.region_name = values_str[:separator_index] 2079 self.assumed_role_arn = values_str[separator_index + 1 :] 2080 2081 # if values_str contains multiple interior slashes such as 2082 # us-east-1/////arn:aws:1234:role/assumed_role, remove 2083 # the extra slashes that come before "arn" 2084 self.assumed_role_arn = self.assumed_role_arn.strip("/") 2085 2086 if self.region_name.startswith("arn"): 2087 raise MlflowException( 2088 message=( 2089 "It looks like the target_uri contains an IAM role ARN without a region name.\n" 2090 "A region name must be provided when the target_uri contains a role ARN.\n" 2091 "In this case, the target_uri must follow the format: " 2092 "sagemaker:/region_name/assumed_role_arn.\n" 2093 f"The provided target_uri is: {self.target_uri}\n" 2094 ), 2095 error_code=INVALID_PARAMETER_VALUE, 2096 ) 2097 2098 def _default_deployment_config(self, create_mode=True): 2099 config = { 2100 "assume_role_arn": self.assumed_role_arn, 2101 "execution_role_arn": None, 2102 "bucket": None, 2103 "image_url": None, 2104 "region_name": self.region_name, 2105 "archive": False, 2106 "instance_type": DEFAULT_SAGEMAKER_INSTANCE_TYPE, 2107 "instance_count": DEFAULT_SAGEMAKER_INSTANCE_COUNT, 2108 "vpc_config": None, 2109 "data_capture_config": None, 2110 "synchronous": True, 2111 "timeout_seconds": 1200, 2112 "variant_name": None, 2113 "env": None, 2114 "tags": None, 2115 "async_inference_config": {}, 2116 "serverless_config": {}, 2117 } 2118 2119 if create_mode: 2120 config["mode"] = DEPLOYMENT_MODE_CREATE 2121 else: 2122 config["mode"] = DEPLOYMENT_MODE_REPLACE 2123 2124 return config 2125 2126 def _apply_custom_config(self, config, custom_config): 2127 int_fields = {"instance_count", "timeout_seconds"} 2128 bool_fields = {"synchronous", "archive"} 2129 dict_fields = { 2130 "vpc_config", 2131 "data_capture_config", 2132 "tags", 2133 "env", 2134 "async_inference_config", 2135 "serverless_config", 2136 } 2137 for key, value in custom_config.items(): 2138 if key not in config: 2139 continue 2140 2141 if key in int_fields and not isinstance(value, int): 2142 value = int(value) 2143 elif key in bool_fields and not isinstance(value, bool): 2144 value = value == "True" 2145 elif key in dict_fields and not isinstance(value, dict): 2146 value = json.loads(value) 2147 2148 config[key] = value 2149 2150 def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None): 2151 """ 2152 Deploy an MLflow model on AWS SageMaker. 2153 The currently active AWS account must have correct permissions set up. 2154 2155 This function creates a SageMaker endpoint. For more information about the input data 2156 formats accepted by this endpoint, see the 2157 `MLflow deployment tools documentation <../../deployment/deploy-model-to-sagemaker.html>`_. 2158 2159 Args: 2160 name: Name of the deployed application. 2161 model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker. 2162 For example: 2163 2164 - ``/Users/me/path/to/local/model`` 2165 - ``relative/path/to/local/model`` 2166 - ``s3://my_bucket/path/to/model`` 2167 - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` 2168 - ``models:/<model_name>/<model_version>`` 2169 - ``models:/<model_name>/<stage>`` 2170 2171 For more information about supported URI schemes, see 2172 `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html# 2173 artifact-locations>`_. 2174 flavor: The name of the flavor of the model to use for deployment. Must be either 2175 ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. 2176 If ``None``, a flavor is automatically selected from the model's available 2177 flavors. If the specified flavor is not present or not supported for 2178 deployment, an exception will be thrown. 2179 config: Configuration parameters. The supported parameters are: 2180 2181 - ``assume_role_arn``: The name of an IAM cross-account role to be assumed 2182 to deploy SageMaker to another AWS account. If this parameter is not 2183 specified, the role given in the ``target_uri`` will be used. If the 2184 role is not given in the ``target_uri``, defaults to ``us-west-2``. 2185 2186 - ``execution_role_arn``: The name of an IAM role granting the SageMaker 2187 service permissions to access the specified Docker image and S3 bucket 2188 containing MLflow model artifacts. If unspecified, the currently-assumed 2189 role will be used. This execution role is passed to the SageMaker service 2190 when creating a SageMaker model from the specified MLflow model. It is 2191 passed as the ``ExecutionRoleArn`` parameter of the `SageMaker 2192 CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/ 2193 dg/API_CreateModel.html>`_. This role is *not* assumed for any other 2194 call. For more information about SageMaker execution roles for model 2195 creation, see 2196 https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html. 2197 2198 - ``bucket``: S3 bucket where model artifacts will be stored. Defaults to a 2199 SageMaker-compatible bucket name. 2200 2201 - ``image_url``: URL of the ECR-hosted Docker image the model should be 2202 deployed into, produced by ``mlflow sagemaker build-and-push-container``. 2203 This parameter can also be specified by the environment variable 2204 ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``. 2205 2206 - ``region_name``: Name of the AWS region to which to deploy the application. 2207 If unspecified, use the region name given in the ``target_uri``. 2208 If it is also not specified in the ``target_uri``, 2209 defaults to ``us-west-2``. 2210 2211 - ``archive``: If ``True``, any pre-existing SageMaker application resources 2212 that become inactive (i.e. as a result of deploying in 2213 ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` mode) are preserved. 2214 These resources may include unused SageMaker models and endpoint 2215 configurations that were associated with a prior version of the 2216 application endpoint. If ``False``, these resources are deleted. 2217 In order to use ``archive=False``, ``create_deployment()`` must be executed 2218 synchronously with ``synchronous=True``. Defaults to ``False``. 2219 2220 - ``instance_type``: The type of SageMaker ML instance on which to deploy the 2221 model. For a list of supported instance types, see 2222 https://aws.amazon.com/sagemaker/pricing/instance-types/. 2223 Defaults to ``ml.m4.xlarge``. 2224 2225 - ``instance_count``: The number of SageMaker ML instances on which to deploy 2226 the model. Defaults to ``1``. 2227 2228 - ``synchronous``: If ``True``, this function will block until the deployment 2229 process succeeds or encounters an irrecoverable failure. If ``False``, 2230 this function will return immediately after starting the deployment 2231 process. It will not wait for the deployment process to complete; 2232 in this case, the caller is responsible for monitoring the health and 2233 status of the pending deployment via native SageMaker APIs or the AWS 2234 console. Defaults to ``True``. 2235 2236 - ``timeout_seconds``: If ``synchronous`` is ``True``, the deployment process 2237 will return after the specified number of seconds if no definitive result 2238 (success or failure) is achieved. Once the function returns, the caller is 2239 responsible for monitoring the health and status of the pending 2240 deployment using native SageMaker APIs or the AWS console. If 2241 ``synchronous`` is ``False``, this parameter is ignored. 2242 Defaults to ``300``. 2243 2244 - ``vpc_config``: A dictionary specifying the VPC configuration to use when 2245 creating the new SageMaker model associated with this application. 2246 The acceptable values for this parameter are identical to those of the 2247 ``VpcConfig`` parameter in the `SageMaker boto3 client's create_model 2248 method <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html 2249 #SageMaker.Client.create_model>`_. For more information, see 2250 https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html. 2251 Defaults to ``None``. 2252 2253 - ``data_capture_config``: A dictionary specifying the data capture 2254 configuration to use when creating the new SageMaker model associated with 2255 this application. 2256 For more information, see 2257 https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html. 2258 Defaults to ``None``. 2259 2260 - ``variant_name``: A string specifying the desired name when creating a production 2261 variant. Defaults to ``None``. 2262 2263 - ``async_inference_config``: A dictionary specifying the 2264 async_inference_configuration 2265 2266 - ``serverless_config``: A dictionary specifying the serverless_configuration 2267 2268 - ``env``: A dictionary specifying environment variables as key-value 2269 pairs to be set for the deployed model. Defaults to ``None``. 2270 2271 - ``tags``: A dictionary of key-value pairs representing additional 2272 tags to be set for the deployed model. Defaults to ``None``. 2273 2274 endpoint: (optional) Endpoint to create the deployment under. Currently unsupported 2275 2276 .. code-block:: python 2277 :caption: Python example 2278 2279 from mlflow.deployments import get_deploy_client 2280 2281 vpc_config = { 2282 "SecurityGroupIds": [ 2283 "sg-123456abc", 2284 ], 2285 "Subnets": [ 2286 "subnet-123456abc", 2287 ], 2288 } 2289 config = dict( 2290 assume_role_arn="arn:aws:123:role/assumed_role", 2291 execution_role_arn="arn:aws:456:role/execution_role", 2292 bucket_name="my-s3-bucket", 2293 image_url="1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1", 2294 region_name="us-east-1", 2295 archive=False, 2296 instance_type="ml.m5.4xlarge", 2297 instance_count=1, 2298 synchronous=True, 2299 timeout_seconds=300, 2300 vpc_config=vpc_config, 2301 variant_name="prod-variant-1", 2302 env={"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": '"--timeout 60"'}, 2303 tags={"training_timestamp": "2022-11-01T05:12:26"}, 2304 ) 2305 client = get_deploy_client("sagemaker") 2306 client.create_deployment( 2307 "my-deployment", 2308 model_uri="/mlruns/0/abc/model", 2309 flavor="python_function", 2310 config=config, 2311 ) 2312 .. code-block:: bash 2313 :caption: Command-line example 2314 2315 mlflow deployments create --target sagemaker:/us-east-1/arn:aws:123:role/assumed_role \\ 2316 --name my-deployment \\ 2317 --model-uri /mlruns/0/abc/model \\ 2318 --flavor python_function\\ 2319 -C execution_role_arn=arn:aws:456:role/execution_role \\ 2320 -C bucket_name=my-s3-bucket \\ 2321 -C image_url=1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1 \\ 2322 -C region_name=us-east-1 \\ 2323 -C archive=False \\ 2324 -C instance_type=ml.m5.4xlarge \\ 2325 -C instance_count=1 \\ 2326 -C synchronous=True \\ 2327 -C timeout_seconds=300 \\ 2328 -C variant_name=prod-variant-1 \\ 2329 -C vpc_config='{"SecurityGroupIds": ["sg-123456abc"], \\ 2330 "Subnets": ["subnet-123456abc"]}' \\ 2331 -C data_capture_config='{"EnableCapture": True, \\ 2332 'InitialSamplingPercentage': 100, 'DestinationS3Uri": 's3://my-bucket/path', \\ 2333 'CaptureOptions': [{'CaptureMode': 'Output'}]}' 2334 -C env='{"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": "\"--timeout 60\""}' \\ 2335 -C tags='{"training_timestamp": "2022-11-01T05:12:26"}' \\ 2336 """ 2337 final_config = self._default_deployment_config() 2338 if config: 2339 self._apply_custom_config(final_config, config) 2340 2341 app_name, flavor = _deploy( 2342 app_name=name, 2343 model_uri=model_uri, 2344 flavor=flavor, 2345 execution_role_arn=final_config["execution_role_arn"], 2346 assume_role_arn=final_config["assume_role_arn"], 2347 bucket=final_config["bucket"], 2348 image_url=final_config["image_url"], 2349 region_name=final_config["region_name"], 2350 mode=mlflow.sagemaker.DEPLOYMENT_MODE_CREATE, 2351 archive=final_config["archive"], 2352 instance_type=final_config["instance_type"], 2353 instance_count=final_config["instance_count"], 2354 vpc_config=final_config["vpc_config"], 2355 data_capture_config=final_config["data_capture_config"], 2356 synchronous=final_config["synchronous"], 2357 timeout_seconds=final_config["timeout_seconds"], 2358 variant_name=final_config["variant_name"], 2359 async_inference_config=final_config["async_inference_config"], 2360 serverless_config=final_config["serverless_config"], 2361 env=final_config["env"], 2362 tags=final_config["tags"], 2363 ) 2364 2365 return {"name": app_name, "flavor": flavor} 2366 2367 def update_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None): 2368 """ 2369 Update a deployment on AWS SageMaker. This function can replace or add a new model to 2370 an existing SageMaker endpoint. By default, this function replaces the existing model 2371 with the new one. The currently active AWS account must have correct permissions set up. 2372 2373 Args: 2374 name: Name of the deployed application. 2375 model_uri: The location, in URI format, of the MLflow model to deploy to SageMaker. 2376 For example: 2377 2378 - ``/Users/me/path/to/local/model`` 2379 - ``relative/path/to/local/model`` 2380 - ``s3://my_bucket/path/to/model`` 2381 - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` 2382 - ``models:/<model_name>/<model_version>`` 2383 - ``models:/<model_name>/<stage>`` 2384 2385 For more information about supported URI schemes, see 2386 `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html# 2387 artifact-locations>`_. 2388 2389 flavor: The name of the flavor of the model to use for deployment. Must be either 2390 ``None`` or one of mlflow.sagemaker.SUPPORTED_DEPLOYMENT_FLAVORS. 2391 If ``None``, a flavor is automatically selected from the model's available 2392 flavors. If the specified flavor is not present or not supported for 2393 deployment, an exception will be thrown. 2394 2395 config: Configuration parameters. The supported parameters are: 2396 2397 - ``assume_role_arn``: The name of an IAM cross-account role to be assumed 2398 to deploy SageMaker to another AWS account. If this parameter is not 2399 specified, the role given in the ``target_uri`` will be used. If the 2400 role is not given in the ``target_uri``, defaults to ``us-west-2``. 2401 2402 - ``execution_role_arn``: The name of an IAM role granting the SageMaker 2403 service permissions to access the specified Docker image and S3 bucket 2404 containing MLflow model artifacts. If unspecified, the currently-assumed 2405 role will be used. This execution role is passed to the SageMaker service 2406 when creating a SageMaker model from the specified MLflow model. It is 2407 passed as the ``ExecutionRoleArn`` parameter of the `SageMaker 2408 CreateModel API call <https://docs.aws.amazon.com/sagemaker/latest/ 2409 dg/API_CreateModel.html>`_. This role is *not* assumed for any other 2410 call. For more information about SageMaker execution roles for model 2411 creation, see 2412 https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html. 2413 2414 - ``bucket``: S3 bucket where model artifacts will be stored. Defaults to a 2415 SageMaker-compatible bucket name. 2416 2417 - ``image_url``: URL of the ECR-hosted Docker image the model should be 2418 deployed into, produced by ``mlflow sagemaker build-and-push-container``. 2419 This parameter can also be specified by the environment variable 2420 ``MLFLOW_SAGEMAKER_DEPLOY_IMG_URL``. 2421 2422 - ``region_name``: Name of the AWS region to which to deploy the application. 2423 If unspecified, use the region name given in the ``target_uri``. 2424 If it is also not specified in the ``target_uri``, 2425 defaults to ``us-west-2``. 2426 2427 - ``mode``: The mode in which to deploy the application. 2428 Must be one of the following: 2429 2430 ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` 2431 If an application of the specified name exists, its model(s) is 2432 replaced with the specified model. If no such application exists, 2433 it is created with the specified name and model. 2434 This is the default mode. 2435 2436 ``mlflow.sagemaker.DEPLOYMENT_MODE_ADD`` 2437 Add the specified model to a pre-existing application with the 2438 specified name, if one exists. If the application does not exist, 2439 a new application is created with the specified name and model. 2440 NOTE: If the application **already exists**, the specified model is 2441 added to the application's corresponding SageMaker endpoint with an 2442 initial weight of zero (0). To route traffic to the model, 2443 update the application's associated endpoint configuration using 2444 either the AWS console or the ``UpdateEndpointWeightsAndCapacities`` 2445 function defined in https://docs.aws.amazon.com/sagemaker/latest/dg/API_UpdateEndpointWeightsAndCapacities.html. 2446 2447 - ``archive``: If ``True``, any pre-existing SageMaker application resources 2448 that become inactive (i.e. as a result of deploying in 2449 ``mlflow.sagemaker.DEPLOYMENT_MODE_REPLACE`` mode) are preserved. 2450 These resources may include unused SageMaker models and endpoint 2451 configurations that were associated with a prior version of the 2452 application endpoint. If ``False``, these resources are deleted. 2453 In order to use ``archive=False``, ``update_deployment()`` must be executed 2454 synchronously with ``synchronous=True``. Defaults to ``False``. 2455 2456 - ``instance_type``: The type of SageMaker ML instance on which to deploy the 2457 model. For a list of supported instance types, see 2458 https://aws.amazon.com/sagemaker/pricing/instance-types/. 2459 Defaults to ``ml.m4.xlarge``. 2460 2461 - ``instance_count``: The number of SageMaker ML instances on which to deploy 2462 the model. Defaults to ``1``. 2463 2464 - ``synchronous``: If ``True``, this function will block until the deployment 2465 process succeeds or encounters an irrecoverable failure. If ``False``, 2466 this function will return immediately after starting the deployment 2467 process. It will not wait for the deployment process to complete; 2468 in this case, the caller is responsible for monitoring the health and 2469 status of the pending deployment via native SageMaker APIs or the AWS 2470 console. Defaults to ``True``. 2471 2472 - ``timeout_seconds``: If ``synchronous`` is ``True``, the deployment process 2473 will return after the specified number of seconds if no definitive result 2474 (success or failure) is achieved. Once the function returns, the caller is 2475 responsible for monitoring the health and status of the pending 2476 deployment using native SageMaker APIs or the AWS console. If 2477 ``synchronous`` is ``False``, this parameter is ignored. 2478 Defaults to ``300``. 2479 2480 - ``variant_name``: A string specifying the desired name when creating a 2481 production variant. Defaults to ``None``. 2482 2483 - ``vpc_config``: A dictionary specifying the VPC configuration to use when 2484 creating the new SageMaker model associated with this application. 2485 The acceptable values for this parameter are identical to those of the 2486 ``VpcConfig`` parameter in the `SageMaker boto3 client's create_model 2487 method <https://boto3.readthedocs.io/en/latest/reference/services/sagemaker.html 2488 #SageMaker.Client.create_model>`_. For more information, see 2489 https://docs.aws.amazon.com/sagemaker/latest/dg/API_VpcConfig.html. 2490 Defaults to ``None``. 2491 2492 - ``data_capture_config``: A dictionary specifying the data capture 2493 configuration to use when creating the new SageMaker model associated with 2494 this application. 2495 For more information, see 2496 https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DataCaptureConfig.html. 2497 Defaults to ``None``. 2498 2499 - ``async_inference_config``: A dictionary specifying the async config 2500 configuration. Defaults to ``None``. 2501 2502 - ``env``: A dictionary specifying environment variables as key-value pairs 2503 to be set for the deployed model. Defaults to ``None``. 2504 2505 - ``tags``: A dictionary of key-value pairs representing additional tags 2506 to be set for the deployed model. Defaults to ``None``. 2507 2508 endpoint: (optional) Endpoint containing the deployment to update. Currently unsupported 2509 2510 .. code-block:: python 2511 :caption: Python example 2512 2513 from mlflow.deployments import get_deploy_client 2514 2515 vpc_config = { 2516 "SecurityGroupIds": [ 2517 "sg-123456abc", 2518 ], 2519 "Subnets": [ 2520 "subnet-123456abc", 2521 ], 2522 } 2523 data_capture_config = { 2524 "EnableCapture": True, 2525 "InitialSamplingPercentage": 100, 2526 "DestinationS3Uri": "s3://my-bucket/path", 2527 "CaptureOptions": [{"CaptureMode": "Output"}], 2528 } 2529 config = dict( 2530 assume_role_arn="arn:aws:123:role/assumed_role", 2531 execution_role_arn="arn:aws:456:role/execution_role", 2532 bucket_name="my-s3-bucket", 2533 image_url="1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1", 2534 region_name="us-east-1", 2535 mode="replace", 2536 archive=False, 2537 instance_type="ml.m5.4xlarge", 2538 instance_count=1, 2539 synchronous=True, 2540 timeout_seconds=300, 2541 variant_name="prod-variant-1", 2542 vpc_config=vpc_config, 2543 data_capture_config=data_capture_config, 2544 env={"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": '"--timeout 60"'}, 2545 tags={"training_timestamp": "2022-11-01T05:12:26"}, 2546 ) 2547 client = get_deploy_client("sagemaker") 2548 client.update_deployment( 2549 "my-deployment", 2550 model_uri="/mlruns/0/abc/model", 2551 flavor="python_function", 2552 config=config, 2553 ) 2554 .. code-block:: bash 2555 :caption: Command-line example 2556 2557 mlflow deployments update --target sagemaker:/us-east-1/arn:aws:123:role/assumed_role \\ 2558 --name my-deployment \\ 2559 --model-uri /mlruns/0/abc/model \\ 2560 --flavor python_function\\ 2561 -C execution_role_arn=arn:aws:456:role/execution_role \\ 2562 -C bucket_name=my-s3-bucket \\ 2563 -C image_url=1234.dkr.ecr.us-east-1.amazonaws.com/mlflow-test:1.23.1 \\ 2564 -C region_name=us-east-1 \\ 2565 -C mode=replace \\ 2566 -C archive=False \\ 2567 -C instance_type=ml.m5.4xlarge \\ 2568 -C instance_count=1 \\ 2569 -C synchronous=True \\ 2570 -C timeout_seconds=300 \\ 2571 -C variant_name=prod-variant-1 \\ 2572 -C vpc_config='{"SecurityGroupIds": ["sg-123456abc"], \\ 2573 "Subnets": ["subnet-123456abc"]}' \\ 2574 -C data_capture_config='{"EnableCapture": True, \\ 2575 "InitialSamplingPercentage": 100, "DestinationS3Uri": "s3://my-bucket/path", \\ 2576 "CaptureOptions": [{"CaptureMode": "Output"}]}' 2577 -C env='{"DISABLE_NGINX": "true", "GUNICORN_CMD_ARGS": "\"--timeout 60\""}' \\ 2578 -C tags='{"training_timestamp": "2022-11-01T05:12:26"}' \\ 2579 """ 2580 final_config = self._default_deployment_config(create_mode=False) 2581 if config: 2582 self._apply_custom_config(final_config, config) 2583 2584 if model_uri is None: 2585 raise MlflowException( 2586 message="A model_uri must be provided when updating a SageMaker deployment", 2587 error_code=INVALID_PARAMETER_VALUE, 2588 ) 2589 2590 if final_config["mode"] not in [DEPLOYMENT_MODE_ADD, DEPLOYMENT_MODE_REPLACE]: 2591 raise MlflowException( 2592 message=( 2593 f"Invalid mode `{final_config['mode']}` for deployment" 2594 " to a pre-existing application" 2595 ), 2596 error_code=INVALID_PARAMETER_VALUE, 2597 ) 2598 2599 app_name, flavor = _deploy( 2600 app_name=name, 2601 model_uri=model_uri, 2602 flavor=flavor, 2603 execution_role_arn=final_config["execution_role_arn"], 2604 assume_role_arn=final_config["assume_role_arn"], 2605 bucket=final_config["bucket"], 2606 image_url=final_config["image_url"], 2607 region_name=final_config["region_name"], 2608 mode=final_config["mode"], 2609 archive=final_config["archive"], 2610 instance_type=final_config["instance_type"], 2611 instance_count=final_config["instance_count"], 2612 vpc_config=final_config["vpc_config"], 2613 data_capture_config=final_config["data_capture_config"], 2614 synchronous=final_config["synchronous"], 2615 timeout_seconds=final_config["timeout_seconds"], 2616 variant_name=final_config["variant_name"], 2617 async_inference_config=final_config["async_inference_config"], 2618 serverless_config=final_config["serverless_config"], 2619 env=final_config["env"], 2620 tags=final_config["tags"], 2621 ) 2622 2623 return {"name": app_name, "flavor": flavor} 2624 2625 def delete_deployment(self, name, config=None, endpoint=None): 2626 """ 2627 Delete a SageMaker application. 2628 2629 Args: 2630 name: Name of the deployed application. 2631 config: Configuration parameters. The supported parameters are: 2632 2633 - ``assume_role_arn``: The name of an IAM role to be assumed to delete 2634 the SageMaker deployment. 2635 2636 - ``region_name``: Name of the AWS region in which the application 2637 is deployed. Defaults to ``us-west-2`` or the region provided in 2638 the `target_uri`. 2639 2640 - ``archive``: If `True`, resources associated with the specified 2641 application, such as its associated models and endpoint configuration, 2642 are preserved. If `False`, these resources are deleted. In order to use 2643 ``archive=False``, ``delete()`` must be executed synchronously with 2644 ``synchronous=True``. Defaults to ``False``. 2645 2646 - ``synchronous``: If `True`, this function blocks until the deletion process 2647 succeeds or encounters an irrecoverable failure. If `False`, this function 2648 returns immediately after starting the deletion process. It will not wait 2649 for the deletion process to complete; in this case, the caller is 2650 responsible for monitoring the status of the deletion process via native 2651 SageMaker APIs or the AWS console. Defaults to ``True``. 2652 2653 - ``timeout_seconds``: If `synchronous` is `True`, the deletion process 2654 returns after the specified number of seconds if no definitive result 2655 (success or failure) is achieved. Once the function returns, the caller 2656 is responsible for monitoring the status of the deletion process via native 2657 SageMaker APIs or the AWS console. If `synchronous` is False, this 2658 parameter is ignored. Defaults to ``300``. 2659 2660 endpoint: (optional) Endpoint containing the deployment to delete. Currently unsupported 2661 2662 .. code-block:: python 2663 :caption: Python example 2664 2665 from mlflow.deployments import get_deploy_client 2666 2667 config = dict( 2668 assume_role_arn="arn:aws:123:role/assumed_role", 2669 region_name="us-east-1", 2670 archive=False, 2671 synchronous=True, 2672 timeout_seconds=300, 2673 ) 2674 client = get_deploy_client("sagemaker") 2675 client.delete_deployment("my-deployment", config=config) 2676 2677 .. code-block:: bash 2678 :caption: Command-line example 2679 2680 mlflow deployments delete --target sagemaker \\ 2681 --name my-deployment \\ 2682 -C assume_role_arn=arn:aws:123:role/assumed_role \\ 2683 -C region_name=us-east-1 \\ 2684 -C archive=False \\ 2685 -C synchronous=True \\ 2686 -C timeout_seconds=300 2687 """ 2688 final_config = { 2689 "region_name": self.region_name, 2690 "archive": False, 2691 "synchronous": True, 2692 "timeout_seconds": 300, 2693 "assume_role_arn": self.assumed_role_arn, 2694 } 2695 if config: 2696 self._apply_custom_config(final_config, config) 2697 2698 _delete( 2699 name, 2700 region_name=final_config["region_name"], 2701 assume_role_arn=final_config["assume_role_arn"], 2702 archive=final_config["archive"], 2703 synchronous=final_config["synchronous"], 2704 timeout_seconds=final_config["timeout_seconds"], 2705 ) 2706 2707 def list_deployments(self, endpoint=None): 2708 """ 2709 List deployments. This method returns a list of dictionaries that describes each deployment. 2710 2711 If a region name needs to be specified, the plugin must be initialized 2712 with the AWS region in the ``target_uri`` such as ``sagemaker:/us-east-1``. 2713 2714 To assume an IAM role, the plugin must be initialized 2715 with the AWS region and the role ARN in the ``target_uri`` such as 2716 ``sagemaker:/us-east-1/arn:aws:1234:role/assumed_role``. 2717 2718 Args: 2719 endpoint: (optional) List deployments in the specified endpoint. Currently unsupported 2720 2721 Returns: 2722 A list of dictionaries corresponding to deployments. 2723 2724 .. code-block:: python 2725 :caption: Python example 2726 2727 from mlflow.deployments import get_deploy_client 2728 2729 client = get_deploy_client("sagemaker:/us-east-1/arn:aws:123:role/assumed_role") 2730 client.list_deployments() 2731 2732 .. code-block:: bash 2733 :caption: Command-line example 2734 2735 mlflow deployments list --target sagemaker:/us-east-1/arn:aws:1234:role/assumed_role 2736 """ 2737 import boto3 2738 2739 assume_role_credentials = _assume_role_and_get_credentials( 2740 assume_role_arn=self.assumed_role_arn 2741 ) 2742 2743 sage_client = boto3.client( 2744 "sagemaker", region_name=self.region_name, **assume_role_credentials 2745 ) 2746 return sage_client.list_endpoints()["Endpoints"] 2747 2748 def get_deployment(self, name, endpoint=None): 2749 """ 2750 Returns a dictionary describing the specified deployment. 2751 2752 If a region name needs to be specified, the plugin must be initialized 2753 with the AWS region in the ``target_uri`` such as ``sagemaker:/us-east-1``. 2754 2755 To assume an IAM role, the plugin must be initialized 2756 with the AWS region and the role ARN in the ``target_uri`` such as 2757 ``sagemaker:/us-east-1/arn:aws:1234:role/assumed_role``. 2758 2759 A :py:class:`mlflow.exceptions.MlflowException` will also be thrown when an error occurs 2760 while retrieving the deployment. 2761 2762 Args: 2763 name: Name of deployment to retrieve 2764 endpoint: (optional) Endpoint containing the deployment to get. Currently unsupported 2765 2766 Returns: 2767 A dictionary that describes the specified deployment 2768 2769 .. code-block:: python 2770 :caption: Python example 2771 2772 from mlflow.deployments import get_deploy_client 2773 2774 client = get_deploy_client("sagemaker:/us-east-1/arn:aws:123:role/assumed_role") 2775 client.get_deployment("my-deployment") 2776 2777 .. code-block:: bash 2778 :caption: Command-line example 2779 2780 mlflow deployments get --target sagemaker:/us-east-1/arn:aws:1234:role/assumed_role \\ 2781 --name my-deployment 2782 """ 2783 import boto3 2784 2785 assume_role_credentials = _assume_role_and_get_credentials( 2786 assume_role_arn=self.assumed_role_arn 2787 ) 2788 2789 try: 2790 sage_client = boto3.client( 2791 "sagemaker", region_name=self.region_name, **assume_role_credentials 2792 ) 2793 return sage_client.describe_endpoint(EndpointName=name) 2794 except Exception as exc: 2795 raise MlflowException( 2796 message=f"There was an error while retrieving the deployment: {exc}\n" 2797 ) 2798 2799 def predict( 2800 self, 2801 deployment_name=None, 2802 inputs=None, 2803 endpoint=None, 2804 params: dict[str, Any] | None = None, 2805 ): 2806 """ 2807 Compute predictions from the specified deployment using the provided PyFunc input. 2808 2809 The input/output types of this method match the :ref:`MLflow PyFunc prediction 2810 interface <pyfunc-inference-api>`. 2811 2812 If a region name needs to be specified, the plugin must be initialized 2813 with the AWS region in the ``target_uri`` such as ``sagemaker:/us-east-1``. 2814 2815 To assume an IAM role, the plugin must be initialized 2816 with the AWS region and the role ARN in the ``target_uri`` such as 2817 ``sagemaker:/us-east-1/arn:aws:1234:role/assumed_role``. 2818 2819 Args: 2820 deployment_name: Name of the deployment to predict against. 2821 inputs: Input data (or arguments) to pass to the deployment or model endpoint for 2822 inference. For a complete list of supported input types, see 2823 :ref:`pyfunc-inference-api`. 2824 endpoint: Endpoint to predict against. Currently unsupported 2825 params: Optional parameters to invoke the endpoint with. 2826 2827 Returns: 2828 A PyFunc output, such as a Pandas DataFrame, Pandas Series, or NumPy array. 2829 For a complete list of supported output types, see :ref:`pyfunc-inference-api`. 2830 2831 .. code-block:: python 2832 :caption: Python example 2833 2834 import pandas as pd 2835 from mlflow.deployments import get_deploy_client 2836 2837 df = pd.DataFrame(data=[[1, 2, 3]], columns=["feat1", "feat2", "feat3"]) 2838 client = get_deploy_client("sagemaker:/us-east-1/arn:aws:123:role/assumed_role") 2839 client.predict("my-deployment", df) 2840 2841 .. code-block:: bash 2842 :caption: Command-line example 2843 2844 cat > ./input.json <<- input 2845 {"feat1": {"0": 1}, "feat2": {"0": 2}, "feat3": {"0": 3}} 2846 input 2847 2848 mlflow deployments predict \\ 2849 --target sagemaker:/us-east-1/arn:aws:1234:role/assumed_role \\ 2850 --name my-deployment \\ 2851 --input-path ./input.json 2852 """ 2853 import boto3 2854 2855 assume_role_credentials = _assume_role_and_get_credentials( 2856 assume_role_arn=self.assumed_role_arn 2857 ) 2858 2859 try: 2860 sage_client = boto3.client( 2861 "sagemaker-runtime", region_name=self.region_name, **assume_role_credentials 2862 ) 2863 response = sage_client.invoke_endpoint( 2864 EndpointName=deployment_name, 2865 Body=dump_input_data(inputs, inputs_key="instances", params=params), 2866 ContentType="application/json", 2867 ) 2868 response_body = response["Body"].read().decode("utf-8") 2869 return PredictionsResponse.from_json(response_body) 2870 except Exception as exc: 2871 raise MlflowException( 2872 message=f"There was an error while getting model prediction: {exc}\n" 2873 ) 2874 2875 def explain(self, deployment_name=None, df=None, endpoint=None): 2876 """ 2877 *This function has not been implemented and will be coming in the future.* 2878 """ 2879 raise NotImplementedError("This function is not implemented yet.") 2880 2881 def create_endpoint(self, name, config=None): 2882 """ 2883 Create an endpoint with the specified target. By default, this method should block until 2884 creation completes (i.e. until it's possible to create a deployment within the endpoint). 2885 In the case of conflicts (e.g. if it's not possible to create the specified endpoint 2886 due to conflict with an existing endpoint), raises a 2887 :py:class:`mlflow.exceptions.MlflowException`. See target-specific plugin documentation 2888 for additional detail on support for asynchronous creation and other configuration. 2889 2890 Args: 2891 name: Unique name to use for endpoint. If another endpoint exists with the same 2892 name, raises a :py:class:`mlflow.exceptions.MlflowException`. 2893 config: (optional) Dict containing target-specific configuration for the endpoint. 2894 2895 Returns: 2896 Dict corresponding to created endpoint, which must contain the 'name' key. 2897 """ 2898 raise NotImplementedError("This function is not implemented yet.") 2899 2900 def update_endpoint(self, endpoint, config=None): 2901 """ 2902 Update the endpoint with the specified name. You can update any target-specific attributes 2903 of the endpoint (via `config`). By default, this method should block until the update 2904 completes (i.e. until it's possible to create a deployment within the endpoint). See 2905 target-specific plugin documentation for additional detail on support for asynchronous 2906 update and other configuration. 2907 2908 Args: 2909 endpoint: Unique name of endpoint to update 2910 config: (optional) dict containing target-specific configuration for the endpoint 2911 """ 2912 raise NotImplementedError("This function is not implemented yet.") 2913 2914 def delete_endpoint(self, endpoint): 2915 """ 2916 Delete the endpoint from the specified target. Deletion should be idempotent (i.e. deletion 2917 should not fail if retried on a non-existent deployment). 2918 2919 Args: 2920 endpoint: Name of endpoint to delete 2921 """ 2922 raise NotImplementedError("This function is not implemented yet.") 2923 2924 def list_endpoints(self): 2925 """ 2926 List endpoints in the specified target. This method is expected to return an 2927 unpaginated list of all endpoints (an alternative would be to return a dict with 2928 an 'endpoints' field containing the actual endpoints, with plugins able to specify 2929 other fields, e.g. a next_page_token field, in the returned dictionary for pagination, 2930 and to accept a `pagination_args` argument to this method for passing 2931 pagination-related args). 2932 2933 Returns: 2934 A list of dicts corresponding to endpoints. Each dict is guaranteed to 2935 contain a 'name' key containing the endpoint name. The other fields of 2936 the returned dictionary and their types may vary across targets. 2937 """ 2938 raise NotImplementedError("This function is not implemented yet.") 2939 2940 def get_endpoint(self, endpoint): 2941 """ 2942 Returns a dictionary describing the specified endpoint, throwing a 2943 py:class:`mlflow.exception.MlflowException` if no endpoint exists with the provided 2944 name. 2945 The dict is guaranteed to contain an 'name' key containing the endpoint name. 2946 The other fields of the returned dictionary and their types may vary across targets. 2947 2948 Args: 2949 endpoint: Name of endpoint to fetch 2950 """ 2951 raise NotImplementedError("This function is not implemented yet.") 2952 2953 2954 class _SageMakerOperation: 2955 def __init__(self, status_check_fn, cleanup_fn): 2956 self.status_check_fn = status_check_fn 2957 self.cleanup_fn = cleanup_fn 2958 self.start_time = time.time() 2959 self.status = _SageMakerOperationStatus(_SageMakerOperationStatus.STATE_IN_PROGRESS, None) 2960 self.cleaned_up = False 2961 2962 def await_completion(self, timeout_seconds): 2963 iteration = 0 2964 begin = time.time() 2965 while (time.time() - begin) < timeout_seconds: 2966 status = self.status_check_fn() 2967 if status.state == _SageMakerOperationStatus.STATE_IN_PROGRESS: 2968 if iteration % 4 == 0: 2969 # Log the progress status roughly every 20 seconds 2970 _logger.info(status.message) 2971 2972 time.sleep(5) 2973 iteration += 1 2974 continue 2975 else: 2976 self.status = status 2977 return status 2978 2979 duration_seconds = time.time() - begin 2980 return _SageMakerOperationStatus.timed_out(duration_seconds) 2981 2982 def clean_up(self): 2983 if self.status.state != _SageMakerOperationStatus.STATE_SUCCEEDED: 2984 raise ValueError( 2985 "Cannot clean up an operation that has not succeeded! Current operation state:" 2986 f" {self.status.state}" 2987 ) 2988 2989 if not self.cleaned_up: 2990 self.cleaned_up = True 2991 else: 2992 raise ValueError("`clean_up()` has already been executed for this operation!") 2993 2994 self.cleanup_fn() 2995 2996 2997 class _SageMakerOperationStatus: 2998 STATE_SUCCEEDED = "succeeded" 2999 STATE_FAILED = "failed" 3000 STATE_IN_PROGRESS = "in progress" 3001 STATE_TIMED_OUT = "timed_out" 3002 3003 def __init__(self, state, message): 3004 self.state = state 3005 self.message = message 3006 3007 @classmethod 3008 def in_progress(cls, message=None): 3009 if message is None: 3010 message = "The operation is still in progress." 3011 return cls(_SageMakerOperationStatus.STATE_IN_PROGRESS, message) 3012 3013 @classmethod 3014 def timed_out(cls, duration_seconds): 3015 return cls( 3016 _SageMakerOperationStatus.STATE_TIMED_OUT, 3017 f"Timed out after waiting {duration_seconds} seconds for the operation to" 3018 " complete. This operation may still be in progress. Please check the AWS" 3019 " console for more information.", 3020 ) 3021 3022 @classmethod 3023 def failed(cls, message): 3024 return cls(_SageMakerOperationStatus.STATE_FAILED, message) 3025 3026 @classmethod 3027 def succeeded(cls, message): 3028 return cls(_SageMakerOperationStatus.STATE_SUCCEEDED, message)