diff --git a/.github/workflows/codebuild-ci.yml b/.github/workflows/codebuild-ci.yml
index 85919f0afe..8c6bd6b337 100644
--- a/.github/workflows/codebuild-ci.yml
+++ b/.github/workflows/codebuild-ci.yml
@@ -55,7 +55,7 @@ jobs:
- name: Run Codestyle & Doc Tests
uses: aws-actions/aws-codebuild-run-build@v1
with:
- project-name: sagemaker-python-sdk-ci-codestyle-doc-tests
+ project-name: ${{ github.event.repository.name }}-ci-codestyle-doc-tests
source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
unit-tests:
runs-on: ubuntu-latest
@@ -74,7 +74,7 @@ jobs:
- name: Run Unit Tests
uses: aws-actions/aws-codebuild-run-build@v1
with:
- project-name: sagemaker-python-sdk-ci-unit-tests
+ project-name: ${{ github.event.repository.name }}-ci-unit-tests
source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
env-vars-for-codebuild: |
PY_VERSION
@@ -93,5 +93,5 @@ jobs:
- name: Run Integ Tests
uses: aws-actions/aws-codebuild-run-build@v1
with:
- project-name: sagemaker-python-sdk-ci-integ-tests
+ project-name: ${{ github.event.repository.name }}-ci-integ-tests
source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 546c3438a0..63e5114f10 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,86 @@
# Changelog
+## v2.221.1 (2024-05-22)
+
+### Bug Fixes and Other Changes
+
+ * Convert pytorchddp distribution to smdistributed distribution
+ * Add tei cpu image
+
+## v2.221.0 (2024-05-20)
+
+### Features
+
+ * onboard tei image config to pysdk
+
+### Bug Fixes and Other Changes
+
+ * JS Model with non-TGI/non-DJL deployment failure
+ * cover tei with image_uris.retrieve API
+ * Add more debugging
+ * model builder limited container support for endpoint mode.
+ * Image URI should take precedence for HF models
+
+## v2.220.0 (2024-05-15)
+
+### Features
+
+ * AutoGluon 1.1.0 image_uris update
+ * add new images for HF TGI release
+ * Add telemetry support for mlflow models
+
+### Bug Fixes and Other Changes
+
+ * add debug logs to workflow container dist creation
+ * model builder race condition on sagemaker session
+ * Add tensorflow_serving support for mlflow models and enable lineage tracking for mlflow models
+ * update image_uri_configs 05-09-2024 07:17:41 PST
+ * skip flaky tests pending investigation
+
+## v2.219.0 (2024-05-08)
+
+### Features
+
+ * allow choosing js payload by alias in private method
+
+### Bug Fixes and Other Changes
+
+ * chore(deps): bump jinja2 from 3.1.3 to 3.1.4 in /requirements/extras
+ * chore(deps): bump tqdm from 4.66.2 to 4.66.3 in /tests/data/serve_resources/mlflow/pytorch
+ * chore(deps): bump jinja2 from 3.1.3 to 3.1.4 in /doc
+ * Updates for SMP v2.3.1
+
+## v2.218.1 (2024-05-03)
+
+### Bug Fixes and Other Changes
+
+ * Fix UserAgent logging in Python SDK
+ * chore: release tgi 2.0.1
+ * chore: update skipped flaky tests
+
+## v2.218.0 (2024-05-01)
+
+### Features
+
+ * set default allow_pickle param to False
+
+### Bug Fixes and Other Changes
+
+ * properly close files in lineage queries and tests
+
+## v2.217.0 (2024-04-24)
+
+### Features
+
+ * support session tag chaining for training job
+
+### Bug Fixes and Other Changes
+
+ * Add Triton v24.03 URI
+ * mainline alt config parsing
+ * Fix tox installs
+ * Add PT 2.2 Graviton Inference DLC
+
## v2.216.1 (2024-04-22)
### Bug Fixes and Other Changes
diff --git a/VERSION b/VERSION
index 9558cc93a5..e55266069e 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-2.216.2.dev0
+2.221.2.dev0
diff --git a/doc/requirements.txt b/doc/requirements.txt
index a65e0e4050..8193dfa22a 100644
--- a/doc/requirements.txt
+++ b/doc/requirements.txt
@@ -2,6 +2,6 @@ sphinx==5.1.1
sphinx-rtd-theme==0.5.0
docutils==0.15.2
packaging==20.9
-jinja2==3.1.3
+jinja2==3.1.4
schema==0.7.5
accelerate>=0.24.1,<=0.27.0
diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt
index 43da930636..889ff72779 100644
--- a/requirements/extras/test_requirements.txt
+++ b/requirements/extras/test_requirements.txt
@@ -12,13 +12,13 @@ awslogs==0.14.0
black==24.3.0
stopit==1.1.2
# Update tox.ini to have correct version of airflow constraints file
-apache-airflow==2.9.0
+apache-airflow==2.9.1
apache-airflow-providers-amazon==7.2.1
attrs>=23.1.0,<24
fabric==2.6.0
requests==2.31.0
sagemaker-experiments==0.1.35
-Jinja2==3.1.3
+Jinja2==3.1.4
pyvis==0.2.1
pandas>=1.3.5,<1.5
scikit-learn==1.3.0
@@ -36,3 +36,4 @@ onnx>=1.15.0
nbformat>=5.9,<6
accelerate>=0.24.1,<=0.27.0
schema==0.7.5
+tensorflow>=2.1,<=2.16
diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py
index 78aa655e04..7541425868 100644
--- a/src/sagemaker/accept_types.py
+++ b/src/sagemaker/accept_types.py
@@ -77,6 +77,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> str:
"""Retrieves the default accept type for the model matching the given arguments.
@@ -98,6 +99,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default accept type to use for the model.
@@ -117,4 +119,5 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
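
The optional config_name parameter added above is threaded identically through the retrieve_default helpers in content_types, deserializers, environment_variables, and hyperparameters below. A minimal sketch of a call site, assuming a hypothetical JumpStart model id and config name:

    from sagemaker import accept_types

    # Both the model id and the config name here are hypothetical placeholders.
    accept_type = accept_types.retrieve_default(
        region="us-west-2",
        model_id="huggingface-llm-example",
        model_version="*",
        config_name="example-config",
    )
    print(accept_type)  # e.g. "application/json"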
diff --git a/src/sagemaker/base_deserializers.py b/src/sagemaker/base_deserializers.py
index 7162e5274d..a152f0144d 100644
--- a/src/sagemaker/base_deserializers.py
+++ b/src/sagemaker/base_deserializers.py
@@ -196,14 +196,14 @@ class NumpyDeserializer(SimpleBaseDeserializer):
single array.
"""
- def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
+ def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=False):
"""Initialize a ``NumpyDeserializer`` instance.
Args:
dtype (str): The dtype of the data (default: None).
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
is expected from the inference endpoint (default: "application/x-npy").
- allow_pickle (bool): Allow loading pickled object arrays (default: True).
+ allow_pickle (bool): Allow loading pickled object arrays (default: False).
"""
super(NumpyDeserializer, self).__init__(accept=accept)
self.dtype = dtype
@@ -227,10 +227,21 @@ def deserialize(self, stream, content_type):
if content_type == "application/json":
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
if content_type == "application/x-npy":
- return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
+ try:
+ return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
+ except ValueError as ve:
+ raise ValueError(
+ "Please set the param allow_pickle=True \
+ to deserialize pickle objects in NumpyDeserializer"
+ ).with_traceback(ve.__traceback__)
if content_type == "application/x-npz":
try:
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
+ except ValueError as ve:
+ raise ValueError(
+ "Please set the param allow_pickle=True \
+ to deserialize pickle objectsin NumpyDeserializer"
+ ).with_traceback(ve.__traceback__)
finally:
stream.close()
finally:
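
With allow_pickle now defaulting to False, pickled (object-dtype) payloads are rejected unless the caller opts in. A minimal local sketch of the new behavior; the deserializer closes the stream, so a fresh buffer is built for each call:

    import io
    import numpy as np
    from sagemaker.base_deserializers import NumpyDeserializer

    def pickled_npy_stream():
        # Object-dtype arrays can only be read back with allow_pickle=True.
        buf = io.BytesIO()
        np.save(buf, np.array([{"a": 1}], dtype=object), allow_pickle=True)
        buf.seek(0)
        return buf

    try:
        NumpyDeserializer().deserialize(pickled_npy_stream(), "application/x-npy")
    except ValueError as e:
        print(e)  # the new default (allow_pickle=False) rejects pickled payloads

    arr = NumpyDeserializer(allow_pickle=True).deserialize(
        pickled_npy_stream(), "application/x-npy"
    )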
diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py
index 46d0361f67..627feca0d6 100644
--- a/src/sagemaker/content_types.py
+++ b/src/sagemaker/content_types.py
@@ -77,6 +77,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> str:
"""Retrieves the default content type for the model matching the given arguments.
@@ -98,6 +99,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default content type to use for the model.
@@ -117,6 +119,7 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py
index 1a4be43897..02e61149ec 100644
--- a/src/sagemaker/deserializers.py
+++ b/src/sagemaker/deserializers.py
@@ -97,6 +97,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> BaseDeserializer:
"""Retrieves the default deserializer for the model matching the given arguments.
@@ -118,6 +119,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
BaseDeserializer: The default deserializer to use for the model.
@@ -138,4 +140,5 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
diff --git a/src/sagemaker/enums.py b/src/sagemaker/enums.py
index 5b4d0d6790..f02b275cbe 100644
--- a/src/sagemaker/enums.py
+++ b/src/sagemaker/enums.py
@@ -28,3 +28,15 @@ class EndpointType(Enum):
INFERENCE_COMPONENT_BASED = (
"InferenceComponentBased" # Amazon SageMaker Inference Component Based Endpoint
)
+
+
+class RoutingStrategy(Enum):
+ """Strategy for routing https traffics."""
+
+ RANDOM = "RANDOM"
+ """The endpoint routes each request to a randomly chosen instance.
+ """
+ LEAST_OUTSTANDING_REQUESTS = "LEAST_OUTSTANDING_REQUESTS"
+ """The endpoint routes requests to the specific instances that have
+ more capacity to process them.
+ """
diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py
index b67066fcde..8fa52c3ec8 100644
--- a/src/sagemaker/environment_variables.py
+++ b/src/sagemaker/environment_variables.py
@@ -36,6 +36,7 @@ def retrieve_default(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
+ config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the default container environment variables for the model matching the arguments.
@@ -65,6 +66,7 @@ def retrieve_default(
variables specific for the instance type.
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
variables.
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: The variables to use for the model.
@@ -87,4 +89,5 @@ def retrieve_default(
sagemaker_session=sagemaker_session,
instance_type=instance_type,
script=script,
+ config_name=config_name,
)
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 066846564e..58a5fabc2f 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -181,6 +181,7 @@ def __init__(
container_arguments: Optional[List[str]] = None,
disable_output_compression: bool = False,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
+ enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
@@ -544,7 +545,9 @@ def __init__(
enable_infra_check (bool or PipelineVariable): Optional.
Specifies whether it is running Sagemaker built-in infra check jobs.
enable_remote_debug (bool or PipelineVariable): Optional.
- Specifies whether RemoteDebug is enabled for the training job
+ Specifies whether RemoteDebug is enabled for the training job.
+ enable_session_tag_chaining (bool or PipelineVariable): Optional.
+ Specifies whether SessionTagChaining is enabled for the training job.
"""
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", instance_count, kwargs
@@ -785,6 +788,8 @@ def __init__(
self._enable_remote_debug = enable_remote_debug
+ self._enable_session_tag_chaining = enable_session_tag_chaining
+
@abstractmethod
def training_image_uri(self):
"""Return the Docker image to use for training.
@@ -2318,6 +2323,14 @@ def get_remote_debug_config(self):
else {"EnableRemoteDebug": self._enable_remote_debug}
)
+ def get_session_chaining_config(self):
+ """dict: Return the configuration of SessionChaining"""
+ return (
+ None
+ if self._enable_session_tag_chaining is None
+ else {"EnableSessionTagChaining": self._enable_session_tag_chaining}
+ )
+
def enable_remote_debug(self):
"""Enable remote debug for a training job."""
self._update_remote_debug(True)
@@ -2574,6 +2587,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
if estimator.get_remote_debug_config() is not None:
train_args["remote_debug_config"] = estimator.get_remote_debug_config()
+ if estimator.get_session_chaining_config() is not None:
+ train_args["session_chaining_config"] = estimator.get_session_chaining_config()
+
return train_args
@classmethod
@@ -2766,6 +2782,7 @@ def __init__(
disable_output_compression: bool = False,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
+ enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
@@ -3129,6 +3146,8 @@ def __init__(
Specifies whether it is running Sagemaker built-in infra check jobs.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
+ enable_session_tag_chaining (bool or PipelineVariable): Optional.
+ Specifies whether SessionTagChaining is enabled for the training job
"""
self.image_uri = image_uri
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
@@ -3181,6 +3200,7 @@ def __init__(
container_arguments=container_arguments,
disable_output_compression=disable_output_compression,
enable_remote_debug=enable_remote_debug,
+ enable_session_tag_chaining=enable_session_tag_chaining,
**kwargs,
)
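
A minimal sketch of the new flag end to end, with placeholder image, role, and instance settings; the constructor stores the value and get_session_chaining_config() surfaces it in the CreateTrainingJob arguments:

    from sagemaker.estimator import Estimator

    estimator = Estimator(
        image_uri="<training-image-uri>",  # placeholder
        role="<execution-role-arn>",       # placeholder
        instance_count=1,
        instance_type="ml.m5.xlarge",
        enable_session_tag_chaining=True,
    )

    # Forwarded into the CreateTrainingJob request by _get_train_args:
    assert estimator.get_session_chaining_config() == {"EnableSessionTagChaining": True}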
diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py
index cf9291a139..be3658365a 100644
--- a/src/sagemaker/fw_utils.py
+++ b/src/sagemaker/fw_utils.py
@@ -142,26 +142,9 @@
"2.1.0",
"2.1.2",
"2.2.0",
- "2.3.0",
],
}
-PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [
- "1.10",
- "1.10.0",
- "1.10.2",
- "1.11",
- "1.11.0",
- "1.12",
- "1.12.0",
- "1.12.1",
- "1.13.1",
- "2.0.0",
- "2.0.1",
- "2.1.0",
- "2.2.0",
-]
-
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
"1.13.1",
"2.0.0",
@@ -170,6 +153,7 @@
"2.1.2",
"2.2.0",
"2.3.0",
+ "2.3.1",
]
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
@@ -795,7 +779,6 @@ def _validate_smdataparallel_args(
Raises:
ValueError: if
- (`instance_type` is not in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES or
`py_version` is not python3 or
`framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSION
"""
@@ -806,17 +789,10 @@ def _validate_smdataparallel_args(
if not smdataparallel_enabled:
return
- is_instance_type_supported = instance_type in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES
-
err_msg = ""
- if not is_instance_type_supported:
- # instance_type is required
- err_msg += (
- f"Provided instance_type {instance_type} is not supported by smdataparallel.\n"
- "Please specify one of the supported instance types:"
- f"{SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES}\n"
- )
+ if not instance_type:
+ err_msg += "Please specify an instance_type for smdataparallel.\n"
if not image_uri:
# ignore framework_version & py_version if image_uri is set
@@ -928,13 +904,6 @@ def validate_distribution(
)
if framework_name and framework_name == "pytorch":
# We need to validate only for PyTorch framework
- validate_pytorch_distribution(
- distribution=validated_distribution,
- framework_name=framework_name,
- framework_version=framework_version,
- py_version=py_version,
- image_uri=image_uri,
- )
validate_torch_distributed_distribution(
instance_type=instance_type,
distribution=validated_distribution,
@@ -968,13 +937,6 @@ def validate_distribution(
)
if framework_name and framework_name == "pytorch":
# We need to validate only for PyTorch framework
- validate_pytorch_distribution(
- distribution=validated_distribution,
- framework_name=framework_name,
- framework_version=framework_version,
- py_version=py_version,
- image_uri=image_uri,
- )
validate_torch_distributed_distribution(
instance_type=instance_type,
distribution=validated_distribution,
@@ -1023,63 +985,6 @@ def validate_distribution_for_instance_type(instance_type, distribution):
raise ValueError(err_msg)
-def validate_pytorch_distribution(
- distribution, framework_name, framework_version, py_version, image_uri
-):
- """Check if pytorch distribution strategy is correctly invoked by the user.
-
- Args:
- distribution (dict): A dictionary with information to enable distributed training.
- (Defaults to None if distributed training is not enabled.) For example:
-
- .. code:: python
-
- {
- "pytorchddp": {
- "enabled": True
- }
- }
- framework_name (str): A string representing the name of framework selected.
- framework_version (str): A string representing the framework version selected.
- py_version (str): A string representing the python version selected.
- image_uri (str): A string representing a Docker image URI.
-
- Raises:
- ValueError: if
- `py_version` is not python3 or
- `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
- """
- if framework_name and framework_name != "pytorch":
- # We need to validate only for PyTorch framework
- return
-
- pytorch_ddp_enabled = False
- if "pytorchddp" in distribution:
- pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
- if not pytorch_ddp_enabled:
- # Distribution strategy other than pytorchddp is selected
- return
-
- err_msg = ""
- if not image_uri:
- # ignore framework_version and py_version if image_uri is set
- # in case image_uri is not set, then both are mandatory
- if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS:
- err_msg += (
- f"Provided framework_version {framework_version} is not supported by"
- " pytorchddp.\n"
- "Please specify one of the supported framework versions:"
- f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n"
- )
- if "py3" not in py_version:
- err_msg += (
- f"Provided py_version {py_version} is not supported by pytorchddp.\n"
- "Please specify py_version>=py3"
- )
- if err_msg:
- raise ValueError(err_msg)
-
-
def validate_torch_distributed_distribution(
instance_type,
distribution,
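
With the dedicated pytorchddp validation removed, a pytorchddp request is handled as the smdistributed data-parallel strategy (per the v2.221.1 changelog entry above). A sketch of the two distribution dicts involved:

    # Previously validated as its own strategy:
    distribution = {"pytorchddp": {"enabled": True}}

    # Now converted to, and validated as, smdistributed data parallel:
    distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}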
diff --git a/src/sagemaker/huggingface/llm_utils.py b/src/sagemaker/huggingface/llm_utils.py
index de5e624dbc..9927d1d293 100644
--- a/src/sagemaker/huggingface/llm_utils.py
+++ b/src/sagemaker/huggingface/llm_utils.py
@@ -65,6 +65,20 @@ def get_huggingface_llm_image_uri(
image_scope="inference",
inference_tool="neuronx",
)
+ if backend == "huggingface-tei":
+ return image_uris.retrieve(
+ "huggingface-tei",
+ region=region,
+ version=version,
+ image_scope="inference",
+ )
+ if backend == "huggingface-tei-cpu":
+ return image_uris.retrieve(
+ "huggingface-tei-cpu",
+ region=region,
+ version=version,
+ image_scope="inference",
+ )
if backend == "lmi":
version = version or "0.24.0"
return image_uris.retrieve(framework="djl-deepspeed", region=region, version=version)
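
The two new backends resolve the GPU and CPU Text Embeddings Inference images added in the huggingface-tei JSON configs further down. A minimal sketch:

    from sagemaker.huggingface import get_huggingface_llm_image_uri

    tei_gpu = get_huggingface_llm_image_uri("huggingface-tei", region="us-east-1", version="1.2.3")
    tei_cpu = get_huggingface_llm_image_uri("huggingface-tei-cpu", region="us-east-1", version="1.2.3")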
diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py
index f71dca0ac8..662baecae6 100644
--- a/src/sagemaker/huggingface/model.py
+++ b/src/sagemaker/huggingface/model.py
@@ -334,6 +334,7 @@ def deploy(
endpoint_type=kwargs.get("endpoint_type", None),
resources=kwargs.get("resources", None),
managed_instance_scaling=kwargs.get("managed_instance_scaling", None),
+ routing_config=kwargs.get("routing_config", None),
)
def register(
diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py
index 5873e37b9f..5c22409c50 100644
--- a/src/sagemaker/hyperparameters.py
+++ b/src/sagemaker/hyperparameters.py
@@ -36,6 +36,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the default training hyperparameters for the model matching the given arguments.
@@ -66,6 +67,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: The hyperparameters to use for the model.
@@ -86,6 +88,7 @@ def retrieve_default(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json
index 57ce47f94c..1ea6441774 100644
--- a/src/sagemaker/image_uri_config/autogluon.json
+++ b/src/sagemaker/image_uri_config/autogluon.json
@@ -11,7 +11,8 @@
"0.6": "0.6.2",
"0.7": "0.7.0",
"0.8": "0.8.2",
- "1.0": "1.0.0"
+ "1.0": "1.0.0",
+ "1.1": "1.1.0"
},
"versions": {
"0.3.1": {
@@ -480,6 +481,47 @@
"py_versions": [
"py310"
]
+ },
+ "1.1.0": {
+ "registries": {
+ "af-south-1": "626614931356",
+ "il-central-1": "780543022126",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ap-southeast-4": "457447274322",
+ "ca-central-1": "763104351884",
+ "eu-central-1": "763104351884",
+ "eu-north-1": "763104351884",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "eu-south-1": "692866216735",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
+ "us-iso-east-1": "886529160074",
+ "us-isob-east-1": "094389454867",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "repository": "autogluon-training",
+ "processors": [
+ "cpu",
+ "gpu"
+ ],
+ "py_versions": [
+ "py310"
+ ]
}
}
},
@@ -491,7 +533,8 @@
"0.6": "0.6.2",
"0.7": "0.7.0",
"0.8": "0.8.2",
- "1.0": "1.0.0"
+ "1.0": "1.0.0",
+ "1.1": "1.1.0"
},
"versions": {
"0.3.1": {
@@ -987,6 +1030,49 @@
"py_versions": [
"py310"
]
+ },
+ "1.1.0": {
+ "registries": {
+ "af-south-1": "626614931356",
+ "il-central-1": "780543022126",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ap-southeast-4": "457447274322",
+ "ca-central-1": "763104351884",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-north-1": "763104351884",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "eu-south-1": "692866216735",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
+ "us-iso-east-1": "886529160074",
+ "us-isob-east-1": "094389454867",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "repository": "autogluon-inference",
+ "processors": [
+ "cpu",
+ "gpu"
+ ],
+ "py_versions": [
+ "py310"
+ ]
}
}
}
diff --git a/src/sagemaker/image_uri_config/djl-lmi.json b/src/sagemaker/image_uri_config/djl-lmi.json
new file mode 100644
index 0000000000..9d2cdd699a
--- /dev/null
+++ b/src/sagemaker/image_uri_config/djl-lmi.json
@@ -0,0 +1,39 @@
+{
+ "scope": [
+ "inference"
+ ],
+ "versions": {
+ "0.28.0": {
+ "registries": {
+ "af-south-1": "626614931356",
+ "il-central-1": "780543022126",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ca-central-1": "763104351884",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-north-1": "763104351884",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "eu-south-1": "692866216735",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "repository": "djl-inference",
+ "tag_prefix": "0.28.0-lmi10.0.0-cu124"
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/sagemaker/image_uri_config/djl-neuronx.json b/src/sagemaker/image_uri_config/djl-neuronx.json
index a63acc87e4..6038946e28 100644
--- a/src/sagemaker/image_uri_config/djl-neuronx.json
+++ b/src/sagemaker/image_uri_config/djl-neuronx.json
@@ -3,6 +3,24 @@
"inference"
],
"versions": {
+ "0.28.0": {
+ "registries": {
+ "ap-northeast-1": "763104351884",
+ "ap-south-1": "763104351884",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "eu-central-1": "763104351884",
+ "eu-west-1": "763104351884",
+ "eu-west-3": "763104351884",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "repository": "djl-inference",
+ "tag_prefix": "0.28.0-neuronx-sdk2.18.2"
+ },
"0.27.0": {
"registries": {
"ap-northeast-1": "763104351884",
diff --git a/src/sagemaker/image_uri_config/djl-tensorrtllm.json b/src/sagemaker/image_uri_config/djl-tensorrtllm.json
index e125cbd419..6cde6109bb 100644
--- a/src/sagemaker/image_uri_config/djl-tensorrtllm.json
+++ b/src/sagemaker/image_uri_config/djl-tensorrtllm.json
@@ -3,6 +3,38 @@
"inference"
],
"versions": {
+ "0.28.0": {
+ "registries": {
+ "af-south-1": "626614931356",
+ "il-central-1": "780543022126",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ca-central-1": "763104351884",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-north-1": "763104351884",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "eu-south-1": "692866216735",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "repository": "djl-inference",
+ "tag_prefix": "0.28.0-tensorrtllm0.9.0-cu122"
+ },
"0.27.0": {
"registries": {
"af-south-1": "626614931356",
diff --git a/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json b/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json
index 9da18c1b56..9efbbea305 100644
--- a/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json
+++ b/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json
@@ -4,7 +4,7 @@
"inf2"
],
"version_aliases": {
- "0.0": "0.0.16"
+ "0.0": "0.0.22"
},
"versions": {
"0.0.16": {
@@ -180,6 +180,35 @@
"container_version": {
"inf2": "ubuntu22.04"
}
+ },
+ "0.0.22": {
+ "py_versions": [
+ "py310"
+ ],
+ "registries": {
+ "ap-northeast-1": "763104351884",
+ "ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-4": "457447274322",
+ "eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
+ "eu-south-2": "503227376785",
+ "eu-west-1": "763104351884",
+ "eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "tag_prefix": "2.1.2-optimum0.0.22",
+ "repository": "huggingface-pytorch-tgi-inference",
+ "container_version": {
+ "inf2": "ubuntu22.04"
+ }
}
}
}
diff --git a/src/sagemaker/image_uri_config/huggingface-llm.json b/src/sagemaker/image_uri_config/huggingface-llm.json
index 10073338e7..3e3f450d23 100644
--- a/src/sagemaker/image_uri_config/huggingface-llm.json
+++ b/src/sagemaker/image_uri_config/huggingface-llm.json
@@ -12,7 +12,7 @@
"1.2": "1.2.0",
"1.3": "1.3.3",
"1.4": "1.4.5",
- "2.0": "2.0.0"
+ "2.0": "2.0.2"
},
"versions": {
"0.6.0": {
@@ -578,6 +578,100 @@
"container_version": {
"gpu": "cu121-ubuntu22.04"
}
+ },
+ "2.0.1": {
+ "py_versions": [
+ "py310"
+ ],
+ "registries": {
+ "af-south-1": "626614931356",
+ "il-central-1": "780543022126",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ap-southeast-4": "457447274322",
+ "ca-central-1": "763104351884",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
+ "eu-north-1": "763104351884",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
+ "me-south-1": "217643126080",
+ "me-central-1": "914824155844",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
+ "us-iso-east-1": "886529160074",
+ "us-isob-east-1": "094389454867",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "tag_prefix": "2.1.1-tgi2.0.1",
+ "repository": "huggingface-pytorch-tgi-inference",
+ "container_version": {
+ "gpu": "cu121-ubuntu22.04"
+ }
+ },
+ "2.0.2": {
+ "py_versions": [
+ "py310"
+ ],
+ "registries": {
+ "af-south-1": "626614931356",
+ "il-central-1": "780543022126",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ap-southeast-4": "457447274322",
+ "ca-central-1": "763104351884",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
+ "eu-north-1": "763104351884",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
+ "me-south-1": "217643126080",
+ "me-central-1": "914824155844",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
+ "us-iso-east-1": "886529160074",
+ "us-isob-east-1": "094389454867",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "tag_prefix": "2.3.0-tgi2.0.2",
+ "repository": "huggingface-pytorch-tgi-inference",
+ "container_version": {
+ "gpu": "cu121-ubuntu22.04"
+ }
}
}
}
diff --git a/src/sagemaker/image_uri_config/huggingface-neuronx.json b/src/sagemaker/image_uri_config/huggingface-neuronx.json
index 0d8b7268b1..3721d75c5f 100644
--- a/src/sagemaker/image_uri_config/huggingface-neuronx.json
+++ b/src/sagemaker/image_uri_config/huggingface-neuronx.json
@@ -5,7 +5,8 @@
],
"version_aliases": {
"4.28": "4.28.1",
- "4.34": "4.34.1"
+ "4.34": "4.34.1",
+ "4.36": "4.36.2"
},
"versions": {
"4.28.1": {
@@ -79,6 +80,42 @@
"sdk2.15.0"
]
}
+ },
+ "4.36.2": {
+ "version_aliases": {
+ "pytorch1.13": "pytorch1.13.1"
+ },
+ "pytorch1.13.1": {
+ "py_versions": [
+ "py310"
+ ],
+ "repository": "huggingface-pytorch-inference-neuronx",
+ "registries": {
+ "ap-northeast-1": "763104351884",
+ "ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-4": "457447274322",
+ "eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
+ "eu-south-2": "503227376785",
+ "eu-west-1": "763104351884",
+ "eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "container_version": {
+ "inf": "ubuntu20.04"
+ },
+ "sdk_versions": [
+ "sdk2.18.0"
+ ]
+ }
}
}
},
@@ -198,7 +235,8 @@
},
"4.36.2": {
"version_aliases": {
- "pytorch1.13": "pytorch1.13.1"
+ "pytorch1.13": "pytorch1.13.1",
+ "pytorch2.1": "pytorch2.1.2"
},
"pytorch1.13.1": {
"py_versions": [
@@ -246,6 +284,53 @@
"sdk_versions": [
"sdk2.16.1"
]
+ },
+ "pytorch2.1.2": {
+ "py_versions": [
+ "py310"
+ ],
+ "repository": "huggingface-pytorch-inference-neuronx",
+ "registries": {
+ "af-south-1": "626614931356",
+ "il-central-1": "780543022126",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-4": "457447274322",
+ "ca-central-1": "763104351884",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
+ "eu-north-1": "763104351884",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
+ "us-iso-east-1": "886529160074",
+ "us-isob-east-1": "094389454867",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "container_version": {
+ "inf": "ubuntu20.04"
+ },
+ "sdk_versions": [
+ "sdk2.18.0"
+ ]
}
}
}
diff --git a/src/sagemaker/image_uri_config/huggingface-tei-cpu.json b/src/sagemaker/image_uri_config/huggingface-tei-cpu.json
new file mode 100644
index 0000000000..d68b0d6307
--- /dev/null
+++ b/src/sagemaker/image_uri_config/huggingface-tei-cpu.json
@@ -0,0 +1,59 @@
+{
+ "inference": {
+ "processors": [
+ "cpu"
+ ],
+ "version_aliases": {
+ "1.2": "1.2.3"
+ },
+ "versions": {
+ "1.2.3": {
+ "py_versions": [
+ "py310"
+ ],
+ "registries": {
+ "af-south-1": "510948584623",
+ "ap-east-1": "651117190479",
+ "ap-northeast-1": "354813040037",
+ "ap-northeast-2": "366743142698",
+ "ap-northeast-3": "867004704886",
+ "ap-south-1": "720646828776",
+ "ap-south-2": "628508329040",
+ "ap-southeast-1": "121021644041",
+ "ap-southeast-2": "783357654285",
+ "ap-southeast-3": "951798379941",
+ "ap-southeast-4": "106583098589",
+ "ca-central-1": "341280168497",
+ "ca-west-1": "190319476487",
+ "cn-north-1": "450853457545",
+ "cn-northwest-1": "451049120500",
+ "eu-central-1": "492215442770",
+ "eu-central-2": "680994064768",
+ "eu-north-1": "662702820516",
+ "eu-south-1": "978288397137",
+ "eu-south-2": "104374241257",
+ "eu-west-1": "141502667606",
+ "eu-west-2": "764974769150",
+ "eu-west-3": "659782779980",
+ "il-central-1": "898809789911",
+ "me-central-1": "272398656194",
+ "me-south-1": "801668240914",
+ "sa-east-1": "737474898029",
+ "us-east-1": "683313688378",
+ "us-east-2": "257758044811",
+ "us-gov-east-1": "237065988967",
+ "us-gov-west-1": "414596584902",
+ "us-iso-east-1": "833128469047",
+ "us-isob-east-1": "281123927165",
+ "us-west-1": "746614075791",
+ "us-west-2": "246618743249"
+ },
+ "tag_prefix": "2.0.1-tei1.2.3",
+ "repository": "tei-cpu",
+ "container_version": {
+ "cpu": "ubuntu22.04"
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/sagemaker/image_uri_config/huggingface-tei.json b/src/sagemaker/image_uri_config/huggingface-tei.json
new file mode 100644
index 0000000000..b7c597df18
--- /dev/null
+++ b/src/sagemaker/image_uri_config/huggingface-tei.json
@@ -0,0 +1,59 @@
+{
+ "inference": {
+ "processors": [
+ "gpu"
+ ],
+ "version_aliases": {
+ "1.2": "1.2.3"
+ },
+ "versions": {
+ "1.2.3": {
+ "py_versions": [
+ "py310"
+ ],
+ "registries": {
+ "af-south-1": "510948584623",
+ "ap-east-1": "651117190479",
+ "ap-northeast-1": "354813040037",
+ "ap-northeast-2": "366743142698",
+ "ap-northeast-3": "867004704886",
+ "ap-south-1": "720646828776",
+ "ap-south-2": "628508329040",
+ "ap-southeast-1": "121021644041",
+ "ap-southeast-2": "783357654285",
+ "ap-southeast-3": "951798379941",
+ "ap-southeast-4": "106583098589",
+ "ca-central-1": "341280168497",
+ "ca-west-1": "190319476487",
+ "cn-north-1": "450853457545",
+ "cn-northwest-1": "451049120500",
+ "eu-central-1": "492215442770",
+ "eu-central-2": "680994064768",
+ "eu-north-1": "662702820516",
+ "eu-south-1": "978288397137",
+ "eu-south-2": "104374241257",
+ "eu-west-1": "141502667606",
+ "eu-west-2": "764974769150",
+ "eu-west-3": "659782779980",
+ "il-central-1": "898809789911",
+ "me-central-1": "272398656194",
+ "me-south-1": "801668240914",
+ "sa-east-1": "737474898029",
+ "us-east-1": "683313688378",
+ "us-east-2": "257758044811",
+ "us-gov-east-1": "237065988967",
+ "us-gov-west-1": "414596584902",
+ "us-iso-east-1": "833128469047",
+ "us-isob-east-1": "281123927165",
+ "us-west-1": "746614075791",
+ "us-west-2": "246618743249"
+ },
+ "tag_prefix": "2.0.1-tei1.2.3",
+ "repository": "tei",
+ "container_version": {
+ "gpu": "cu122-ubuntu22.04"
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
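
Per the changelog entry "cover tei with image_uris.retrieve API", the same image is reachable through the generic retrieve call; a minimal sketch:

    from sagemaker import image_uris

    uri = image_uris.retrieve(
        framework="huggingface-tei",
        region="us-west-2",
        version="1.2.3",
        image_scope="inference",
    )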
diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json
index faf7d6a14a..518da5f15d 100644
--- a/src/sagemaker/image_uri_config/pytorch-smp.json
+++ b/src/sagemaker/image_uri_config/pytorch-smp.json
@@ -6,8 +6,8 @@
"version_aliases": {
"2.0": "2.0.1",
"2.1": "2.1.2",
- "2.2": "2.3.0",
- "2.2.0": "2.3.0"
+ "2.2": "2.3.1",
+ "2.2.0": "2.3.1"
},
"versions": {
"2.0.1": {
@@ -109,6 +109,31 @@
"us-west-2": "658645717510"
},
"repository": "smdistributed-modelparallel"
+ },
+ "2.3.1": {
+ "py_versions": [
+ "py310"
+ ],
+ "registries": {
+ "ap-northeast-1": "658645717510",
+ "ap-northeast-2": "658645717510",
+ "ap-northeast-3": "658645717510",
+ "ap-south-1": "658645717510",
+ "ap-southeast-1": "658645717510",
+ "ap-southeast-2": "658645717510",
+ "ca-central-1": "658645717510",
+ "eu-central-1": "658645717510",
+ "eu-north-1": "658645717510",
+ "eu-west-1": "658645717510",
+ "eu-west-2": "658645717510",
+ "eu-west-3": "658645717510",
+ "sa-east-1": "658645717510",
+ "us-east-1": "658645717510",
+ "us-east-2": "658645717510",
+ "us-west-1": "658645717510",
+ "us-west-2": "658645717510"
+ },
+ "repository": "smdistributed-modelparallel"
}
}
}
diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json
index f068c68149..b846d51246 100644
--- a/src/sagemaker/image_uri_config/pytorch.json
+++ b/src/sagemaker/image_uri_config/pytorch.json
@@ -23,17 +23,17 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
- "eu-north-1": "763104351884",
"eu-central-2": "380420809688",
- "eu-west-1": "763104351884",
+ "eu-north-1": "763104351884",
"eu-south-2": "503227376785",
+ "eu-west-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference-eia"
},
@@ -47,13 +47,13 @@
"ap-south-2": "772153158452",
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
+ "ca-west-1": "204538143572",
"eu-central-2": "380420809688",
- "eu-west-1": "763104351884",
"eu-south-2": "503227376785",
+ "eu-west-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference-eia"
}
@@ -81,7 +81,8 @@
"1.12": "1.12.1",
"1.13": "1.13.1",
"2.0": "2.0.1",
- "2.1": "2.1.0"
+ "2.1": "2.1.0",
+ "2.2": "2.2.0"
},
"versions": {
"0.4.0": {
@@ -193,7 +194,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -205,18 +205,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -225,8 +227,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -237,7 +238,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -249,18 +249,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -269,18 +271,17 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
"1.4.0": {
"py_versions": [
- "py3"
+ "py3",
+ "py36"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -292,18 +293,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -312,18 +315,17 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
"1.5.0": {
"py_versions": [
- "py3"
+ "py3",
+ "py36"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -335,18 +337,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -355,8 +359,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -367,7 +370,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -379,18 +381,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -399,8 +403,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -411,7 +414,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -423,18 +425,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -443,8 +447,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -455,7 +458,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -467,18 +469,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -487,8 +491,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -499,7 +502,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -511,18 +513,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -531,8 +535,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -542,7 +545,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -554,18 +556,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -574,8 +578,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -585,7 +588,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -597,18 +599,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -617,8 +621,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -628,7 +631,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -640,18 +642,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -660,8 +664,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -671,7 +674,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -683,18 +685,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -703,8 +707,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -714,7 +717,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -726,18 +728,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -746,8 +750,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -757,7 +760,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -769,18 +771,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -789,8 +793,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -800,7 +803,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -812,16 +814,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -831,8 +836,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -842,7 +846,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -854,16 +857,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -873,8 +879,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -884,7 +889,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -896,16 +900,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -915,8 +922,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -926,7 +932,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -938,16 +943,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -957,8 +965,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -968,7 +975,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -980,16 +986,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -999,8 +1008,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
},
@@ -1010,7 +1018,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1022,16 +1029,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -1041,8 +1051,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-inference"
}
@@ -1060,12 +1069,14 @@
},
"versions": {
"1.12.1": {
+ "container_version": {
+ "cpu": "ubuntu20.04"
+ },
"py_versions": [
"py38"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1077,16 +1088,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -1096,21 +1110,19 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
- "repository": "pytorch-inference-graviton",
- "container_version": {
- "cpu": "ubuntu20.04"
- }
+ "repository": "pytorch-inference-graviton"
},
"2.0.0": {
+ "container_version": {
+ "cpu": "ubuntu20.04"
+ },
"py_versions": [
"py310"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1122,34 +1134,39 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
- "repository": "pytorch-inference-graviton",
- "container_version": {
- "cpu": "ubuntu20.04"
- }
+ "repository": "pytorch-inference-graviton"
},
"2.0.1": {
+ "container_version": {
+ "cpu": "ubuntu20.04"
+ },
"py_versions": [
"py310"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1161,34 +1178,39 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
- "repository": "pytorch-inference-graviton",
- "container_version": {
- "cpu": "ubuntu20.04"
- }
+ "repository": "pytorch-inference-graviton"
},
"2.1.0": {
+ "container_version": {
+ "cpu": "ubuntu20.04"
+ },
"py_versions": [
"py310"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1200,34 +1222,39 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
- "repository": "pytorch-inference-graviton",
- "container_version": {
- "cpu": "ubuntu20.04"
- }
+ "repository": "pytorch-inference-graviton"
},
"2.2.1": {
+ "container_version": {
+ "cpu": "ubuntu20.04"
+ },
"py_versions": [
"py310"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1239,16 +1266,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -1256,13 +1286,9 @@
"us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
- "repository": "pytorch-inference-graviton",
- "container_version": {
- "cpu": "ubuntu20.04"
- }
+ "repository": "pytorch-inference-graviton"
}
}
},
@@ -1289,7 +1315,8 @@
"1.13": "1.13.1",
"2.0": "2.0.1",
"2.1": "2.1.0",
- "2.2": "2.2.0"
+ "2.2": "2.2.0",
+ "2.3": "2.3.0"
},
"versions": {
"0.4.0": {
@@ -1401,7 +1428,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1413,18 +1439,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1433,8 +1461,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1445,7 +1472,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1457,18 +1483,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1477,19 +1505,18 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
"1.4.0": {
"py_versions": [
"py2",
- "py3"
+ "py3",
+ "py36"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1501,18 +1528,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1521,18 +1550,17 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
"1.5.0": {
"py_versions": [
- "py3"
+ "py3",
+ "py36"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1544,18 +1572,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1564,8 +1594,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1576,7 +1605,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1588,18 +1616,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1608,8 +1638,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1620,7 +1649,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1632,18 +1660,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1652,8 +1682,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1664,7 +1693,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1676,18 +1704,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1696,8 +1726,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1708,7 +1737,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1720,18 +1748,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1740,8 +1770,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1751,7 +1780,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1763,18 +1791,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1783,8 +1813,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1794,7 +1823,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1806,18 +1834,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1826,8 +1856,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1837,7 +1866,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1849,18 +1877,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1869,8 +1899,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1880,7 +1909,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1892,18 +1920,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1912,8 +1942,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1923,7 +1952,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1935,18 +1963,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1955,8 +1985,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -1966,7 +1995,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1978,18 +2006,20 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1998,8 +2028,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -2009,7 +2038,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2021,16 +2049,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2040,8 +2071,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -2051,7 +2081,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2063,16 +2092,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2082,8 +2114,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -2093,7 +2124,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2105,16 +2135,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2124,8 +2157,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -2135,7 +2167,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2147,16 +2178,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2166,8 +2200,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -2177,7 +2210,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2189,16 +2221,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2208,8 +2243,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
},
@@ -2219,7 +2253,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2231,16 +2264,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2250,11 +2286,51 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
+ },
+ "repository": "pytorch-training"
+ },
+ "2.3.0": {
+ "py_versions": [
+ "py311"
+ ],
+ "registries": {
+ "af-south-1": "626614931356",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ap-southeast-4": "457447274322",
+ "ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
+ "eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884"
},
"repository": "pytorch-training"
}
}
}
-}
+}
\ No newline at end of file
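
The registry maps edited above are the lookup tables behind sagemaker.image_uris.retrieve, which resolves the per-region account ID and assembles the full ECR image URI. A minimal sketch of resolving one of the entries this diff adds (PyTorch 2.3.0 training in ca-west-1); the instance type and the exact tag suffix are assumptions for illustration and vary by build:

from sagemaker import image_uris

# Resolve the newly added PyTorch 2.3.0 training image in ca-west-1.
# retrieve() reads the "registries" map from pytorch.json (account
# 204538143572 for ca-west-1, per the hunks above) and builds the URI.
uri = image_uris.retrieve(
    framework="pytorch",
    region="ca-west-1",
    version="2.3.0",
    py_version="py311",
    instance_type="ml.g5.2xlarge",  # GPU instance selects the GPU variant
    image_scope="training",
)
print(uri)
# Expected shape (tag suffix assumed, differs per CUDA build):
# 204538143572.dkr.ecr.ca-west-1.amazonaws.com/pytorch-training:2.3.0-gpu-py311-...
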
diff --git a/src/sagemaker/image_uri_config/sagemaker-tritonserver.json b/src/sagemaker/image_uri_config/sagemaker-tritonserver.json
index 82397d913e..b2257ce803 100644
--- a/src/sagemaker/image_uri_config/sagemaker-tritonserver.json
+++ b/src/sagemaker/image_uri_config/sagemaker-tritonserver.json
@@ -7,7 +7,7 @@
"inference"
],
"versions": {
- "23.12": {
+ "24.03": {
"registries": {
"af-south-1": "626614931356",
"il-central-1": "780543022126",
@@ -37,7 +37,7 @@
"ca-west-1": "204538143572"
},
"repository": "sagemaker-tritonserver",
- "tag_prefix": "23.12-py3"
+ "tag_prefix": "24.03-py3"
},
"24.01": {
"registries": {
@@ -70,6 +70,38 @@
},
"repository": "sagemaker-tritonserver",
"tag_prefix": "24.01-py3"
+ },
+ "23.12": {
+ "registries": {
+ "af-south-1": "626614931356",
+ "il-central-1": "780543022126",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ca-central-1": "763104351884",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-north-1": "763104351884",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "eu-south-1": "692866216735",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884",
+ "ca-west-1": "204538143572"
+ },
+ "repository": "sagemaker-tritonserver",
+ "tag_prefix": "23.12-py3"
}
}
}
\ No newline at end of file
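
The sagemaker-tritonserver hunks above only reorder the version entries and add 24.03; every field needed to form an image URI is visible in the JSON itself. A small sketch that assembles the URI by hand from the parsed config, assuming config_for_framework keeps its current location in sagemaker.image_uris (the standard ECR hostname below does not hold for the cn-* partitions):

from sagemaker.image_uris import config_for_framework

# Parse sagemaker-tritonserver.json as shown above and confirm the
# reordering left all three server releases resolvable.
config = config_for_framework("sagemaker-tritonserver")
assert set(config["versions"]) == {"23.12", "24.01", "24.03"}

# The per-version fields map directly onto the URI components.
entry = config["versions"]["24.03"]
account = entry["registries"]["us-west-2"]  # 763104351884, per the hunk above
uri = "{}.dkr.ecr.us-west-2.amazonaws.com/{}:{}".format(
    account, entry["repository"], entry["tag_prefix"]
)
print(uri)  # 763104351884.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:24.03-py3
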
diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json
index 5dc8d35af2..9194dfbe4a 100644
--- a/src/sagemaker/image_uri_config/tensorflow.json
+++ b/src/sagemaker/image_uri_config/tensorflow.json
@@ -140,7 +140,6 @@
"1.14.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -152,6 +151,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -162,8 +162,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -172,15 +173,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference-eia"
},
"1.15.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -192,6 +191,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -202,8 +202,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -212,15 +213,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference-eia"
},
"2.0.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -232,6 +231,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -242,8 +242,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -252,15 +253,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference-eia"
},
"2.3.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -272,6 +271,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -282,8 +282,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -292,8 +293,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference-eia"
}
@@ -305,18 +305,18 @@
"gpu"
],
"version_aliases": {
- "1.10": "1.10.0",
- "1.11": "1.11.0",
- "1.12": "1.12.0",
- "1.13": "1.13.0",
- "1.14": "1.14.0",
- "1.15": "1.15.5",
"1.4": "1.4.1",
"1.5": "1.5.0",
"1.6": "1.6.0",
"1.7": "1.7.0",
"1.8": "1.8.0",
"1.9": "1.9.0",
+ "1.10": "1.10.0",
+ "1.11": "1.11.0",
+ "1.12": "1.12.0",
+ "1.13": "1.13.0",
+ "1.14": "1.14.0",
+ "1.15": "1.15.5",
"2.0": "2.0.4",
"2.1": "2.1.3",
"2.2": "2.2.2",
@@ -334,6 +334,204 @@
"2.14": "2.14.1"
},
"versions": {
+ "1.4.1": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
+ "1.5.0": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
+ "1.6.0": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
+ "1.7.0": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
+ "1.8.0": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
+ "1.9.0": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
"1.10.0": {
"py_versions": [
"py2"
@@ -430,7 +628,6 @@
"1.13.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -442,6 +639,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -452,8 +650,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -462,15 +661,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"1.14.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -482,6 +679,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -492,8 +690,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -502,15 +701,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"1.15.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -522,6 +719,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -532,8 +730,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -542,15 +741,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"1.15.2": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -562,6 +759,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -572,8 +770,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -582,15 +781,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"1.15.3": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -602,6 +799,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -612,8 +810,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -622,15 +821,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"1.15.4": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -642,6 +839,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -652,8 +850,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -662,15 +861,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"1.15.5": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -682,6 +879,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -692,8 +890,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -702,213 +901,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
- "1.4.1": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
- "1.5.0": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
- "1.6.0": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
- "1.7.0": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
- "1.8.0": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
- "1.9.0": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
"2.0.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -920,6 +919,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -930,8 +930,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -940,15 +941,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.0.1": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -960,6 +959,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -970,8 +970,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -980,15 +981,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.0.2": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1000,6 +999,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1010,8 +1010,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1020,15 +1021,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.0.3": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1040,6 +1039,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1050,8 +1050,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1060,15 +1061,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.0.4": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1080,6 +1079,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1090,8 +1090,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1100,15 +1101,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.1.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1120,6 +1119,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1130,8 +1130,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1140,15 +1141,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.1.1": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1160,6 +1159,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1170,8 +1170,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1180,15 +1181,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.1.2": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1200,6 +1199,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1210,8 +1210,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1220,15 +1221,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.1.3": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1240,6 +1239,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1250,8 +1250,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1260,15 +1261,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.2.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1280,6 +1279,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1290,8 +1290,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1300,15 +1301,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.2.1": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1320,6 +1319,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1330,8 +1330,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1340,15 +1341,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.2.2": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1360,6 +1359,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1370,8 +1370,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1380,15 +1381,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.3.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1400,6 +1399,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1410,8 +1410,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1420,15 +1421,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.3.1": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1440,6 +1439,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1450,8 +1450,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1460,15 +1461,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.3.2": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1480,6 +1479,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1490,8 +1490,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1500,15 +1501,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.4.1": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1520,6 +1519,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1530,8 +1530,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1540,15 +1541,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.4.3": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1560,6 +1559,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1570,8 +1570,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1580,15 +1581,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.5.1": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1600,6 +1599,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1610,8 +1610,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1620,15 +1621,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.6.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1640,6 +1639,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1650,8 +1650,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1660,15 +1661,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.6.3": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1680,6 +1679,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1690,8 +1690,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1700,15 +1701,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.7.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1720,6 +1719,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1730,8 +1730,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1740,15 +1741,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.8.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1760,6 +1759,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1770,8 +1770,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1780,15 +1781,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.8.4": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1800,6 +1799,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1810,8 +1810,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -1820,15 +1821,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.9.2": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1840,6 +1839,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1850,6 +1850,8 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -1859,15 +1861,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.9.3": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1879,6 +1879,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1889,6 +1890,8 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -1898,15 +1901,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.10.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1918,6 +1919,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1928,6 +1930,8 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -1937,15 +1941,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.10.1": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1957,6 +1959,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -1967,6 +1970,8 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -1976,15 +1981,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.11.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -1996,6 +1999,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -2006,6 +2010,8 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2015,15 +2021,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.11.1": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2035,6 +2039,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -2045,6 +2050,8 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2054,15 +2061,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.12.1": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2074,6 +2079,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -2084,6 +2090,8 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2093,15 +2101,13 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
"2.13.0": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2113,6 +2119,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -2123,17 +2130,18 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
- "us-iso-east-1": "886529160074",
- "us-isob-east-1": "094389454867",
"us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
+ "us-iso-east-1": "886529160074",
+ "us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-inference"
},
@@ -2191,12 +2199,14 @@
},
"versions": {
"2.9.1": {
+ "container_version": {
+ "cpu": "ubuntu20.04"
+ },
"py_versions": [
"py38"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2208,16 +2218,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2227,21 +2240,19 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
- "repository": "tensorflow-inference-graviton",
- "container_version": {
- "cpu": "ubuntu20.04"
- }
+ "repository": "tensorflow-inference-graviton"
},
"2.12.1": {
+ "container_version": {
+ "cpu": "ubuntu20.04"
+ },
"py_versions": [
"py310"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2253,16 +2264,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2272,21 +2286,19 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
- "repository": "tensorflow-inference-graviton",
- "container_version": {
- "cpu": "ubuntu20.04"
- }
+ "repository": "tensorflow-inference-graviton"
},
"2.13.0": {
+ "container_version": {
+ "cpu": "ubuntu20.04"
+ },
"py_versions": [
"py310"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2298,16 +2310,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2317,21 +2332,19 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
- "repository": "tensorflow-inference-graviton",
- "container_version": {
- "cpu": "ubuntu20.04"
- }
+ "repository": "tensorflow-inference-graviton"
},
"2.14.1": {
+ "container_version": {
+ "cpu": "ubuntu20.04"
+ },
"py_versions": [
"py310"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2343,16 +2356,19 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-central-2": "380420809688",
"eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -2362,13 +2378,9 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
- "repository": "tensorflow-inference-graviton",
- "container_version": {
- "cpu": "ubuntu20.04"
- }
+ "repository": "tensorflow-inference-graviton"
}
}
},
@@ -2378,18 +2390,18 @@
"gpu"
],
"version_aliases": {
- "1.10": "1.10.0",
- "1.11": "1.11.0",
- "1.12": "1.12.0",
- "1.13": "1.13.1",
- "1.14": "1.14.0",
- "1.15": "1.15.5",
"1.4": "1.4.1",
"1.5": "1.5.0",
"1.6": "1.6.0",
"1.7": "1.7.0",
"1.8": "1.8.0",
"1.9": "1.9.0",
+ "1.10": "1.10.0",
+ "1.11": "1.11.0",
+ "1.12": "1.12.0",
+ "1.13": "1.13.1",
+ "1.14": "1.14.0",
+ "1.15": "1.15.5",
"2.0": "2.0.4",
"2.1": "2.1.3",
"2.2": "2.2.2",
@@ -2407,6 +2419,204 @@
"2.14": "2.14.1"
},
"versions": {
+ "1.4.1": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
+ "1.5.0": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
+ "1.6.0": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
+ "1.7.0": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
+ "1.8.0": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
+ "1.9.0": {
+ "py_versions": [
+ "py2"
+ ],
+ "registries": {
+ "af-south-1": "313743910680",
+ "ap-east-1": "057415533634",
+ "ap-northeast-1": "520713654638",
+ "ap-northeast-2": "520713654638",
+ "ap-south-1": "520713654638",
+ "ap-southeast-1": "520713654638",
+ "ap-southeast-2": "520713654638",
+ "ca-central-1": "520713654638",
+ "cn-north-1": "422961961927",
+ "cn-northwest-1": "423003514399",
+ "eu-central-1": "520713654638",
+ "eu-north-1": "520713654638",
+ "eu-south-1": "048378556238",
+ "eu-west-1": "520713654638",
+ "eu-west-2": "520713654638",
+ "eu-west-3": "520713654638",
+ "me-south-1": "724002660598",
+ "sa-east-1": "520713654638",
+ "us-east-1": "520713654638",
+ "us-east-2": "520713654638",
+ "us-gov-west-1": "246785580436",
+ "us-iso-east-1": "744548109606",
+ "us-isob-east-1": "453391408702",
+ "us-west-1": "520713654638",
+ "us-west-2": "520713654638"
+ },
+ "repository": "sagemaker-tensorflow"
+ },
"1.10.0": {
"py_versions": [
"py2"
@@ -2542,7 +2752,6 @@
"py3": {
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2554,6 +2763,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -2564,8 +2774,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -2574,8 +2785,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
}
@@ -2587,7 +2797,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2599,6 +2808,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -2609,8 +2819,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -2619,8 +2830,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -2631,7 +2841,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2643,6 +2852,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -2653,8 +2863,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -2663,8 +2874,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -2676,7 +2886,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2688,6 +2897,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -2698,8 +2908,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -2708,8 +2919,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -2721,7 +2931,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2733,6 +2942,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -2743,8 +2953,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -2753,8 +2964,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -2766,7 +2976,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -2778,6 +2987,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -2788,53 +2998,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
- "me-central-1": "914824155844",
- "sa-east-1": "763104351884",
- "us-east-1": "763104351884",
- "us-east-2": "763104351884",
- "us-gov-east-1": "446045086412",
- "us-gov-west-1": "442386744353",
- "us-iso-east-1": "886529160074",
- "us-isob-east-1": "094389454867",
- "us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
- },
- "repository": "tensorflow-training"
- },
- "1.15.5": {
- "py_versions": [
- "py3",
- "py36",
- "py37"
- ],
- "registries": {
- "af-south-1": "626614931356",
"il-central-1": "780543022126",
- "ap-east-1": "871362719292",
- "ap-northeast-1": "763104351884",
- "ap-northeast-2": "763104351884",
- "ap-northeast-3": "364406365360",
- "ap-south-1": "763104351884",
- "ap-south-2": "772153158452",
- "ap-southeast-1": "763104351884",
- "ap-southeast-2": "763104351884",
- "ap-southeast-3": "907027046896",
- "ap-southeast-4": "457447274322",
- "ca-central-1": "763104351884",
- "cn-north-1": "727897471807",
- "cn-northwest-1": "727897471807",
- "eu-central-1": "763104351884",
- "eu-central-2": "380420809688",
- "eu-north-1": "763104351884",
- "eu-south-1": "692866216735",
- "eu-south-2": "503227376785",
- "eu-west-1": "763104351884",
- "eu-west-2": "763104351884",
- "eu-west-3": "763104351884",
- "me-south-1": "217643126080",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -2843,208 +3009,54 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
- "1.4.1": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
- "1.5.0": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
- "1.6.0": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
- "1.7.0": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
- "1.8.0": {
- "py_versions": [
- "py2"
- ],
- "registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
- },
- "repository": "sagemaker-tensorflow"
- },
- "1.9.0": {
+ "1.15.5": {
"py_versions": [
- "py2"
+ "py3",
+ "py36",
+ "py37"
],
"registries": {
- "af-south-1": "313743910680",
- "ap-east-1": "057415533634",
- "ap-northeast-1": "520713654638",
- "ap-northeast-2": "520713654638",
- "ap-south-1": "520713654638",
- "ap-southeast-1": "520713654638",
- "ap-southeast-2": "520713654638",
- "ca-central-1": "520713654638",
- "cn-north-1": "422961961927",
- "cn-northwest-1": "423003514399",
- "eu-central-1": "520713654638",
- "eu-north-1": "520713654638",
- "eu-south-1": "048378556238",
- "eu-west-1": "520713654638",
- "eu-west-2": "520713654638",
- "eu-west-3": "520713654638",
- "me-south-1": "724002660598",
- "sa-east-1": "520713654638",
- "us-east-1": "520713654638",
- "us-east-2": "520713654638",
- "us-gov-west-1": "246785580436",
- "us-iso-east-1": "744548109606",
- "us-isob-east-1": "453391408702",
- "us-west-1": "520713654638",
- "us-west-2": "520713654638"
+ "af-south-1": "626614931356",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ap-southeast-4": "457447274322",
+ "ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
+ "eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
+ "us-iso-east-1": "886529160074",
+ "us-isob-east-1": "094389454867",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884"
},
- "repository": "sagemaker-tensorflow"
+ "repository": "tensorflow-training"
},
"2.0.0": {
"py_versions": [
@@ -3053,7 +3065,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3065,6 +3076,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3075,8 +3087,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3085,8 +3098,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3097,7 +3109,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3109,6 +3120,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3119,8 +3131,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3129,8 +3142,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3141,7 +3153,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3153,6 +3164,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3163,8 +3175,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3173,18 +3186,17 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
"2.0.3": {
"py_versions": [
- "py3"
+ "py3",
+ "py36"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3196,6 +3208,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3206,8 +3219,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3216,18 +3230,17 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
"2.0.4": {
"py_versions": [
- "py3"
+ "py3",
+ "py36"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3239,6 +3252,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3249,8 +3263,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3259,8 +3274,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3271,7 +3285,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3283,6 +3296,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3293,8 +3307,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3303,8 +3318,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3315,7 +3329,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3327,6 +3340,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3337,8 +3351,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3347,18 +3362,17 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
"2.1.2": {
"py_versions": [
- "py3"
+ "py3",
+ "py36"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3370,6 +3384,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3380,8 +3395,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3390,18 +3406,17 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
"2.1.3": {
"py_versions": [
- "py3"
+ "py3",
+ "py36"
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3413,6 +3428,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3423,8 +3439,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3433,8 +3450,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3444,7 +3460,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3456,6 +3471,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3466,8 +3482,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3476,8 +3493,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3487,7 +3503,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3499,6 +3514,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3509,8 +3525,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3519,8 +3536,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3530,7 +3546,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3542,6 +3557,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3552,8 +3568,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3562,8 +3579,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3573,7 +3589,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3585,6 +3600,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3595,8 +3611,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3605,8 +3622,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3616,7 +3632,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3628,6 +3643,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3638,8 +3654,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3648,8 +3665,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3659,7 +3675,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3671,6 +3686,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3681,8 +3697,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3691,8 +3708,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3702,7 +3718,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3714,6 +3729,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3724,8 +3740,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3734,8 +3751,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3745,7 +3761,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3757,6 +3772,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3767,8 +3783,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3777,8 +3794,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3788,7 +3804,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3800,6 +3815,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3810,8 +3826,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3820,8 +3837,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3831,7 +3847,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3843,6 +3858,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3853,8 +3869,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3863,8 +3880,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3874,7 +3890,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3886,6 +3901,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3896,8 +3912,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3906,8 +3923,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3917,7 +3933,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3929,6 +3944,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3939,8 +3955,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3949,8 +3966,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -3960,7 +3976,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -3972,6 +3987,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -3982,8 +3998,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -3992,8 +4009,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -4003,7 +4019,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -4015,6 +4030,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -4025,8 +4041,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -4035,8 +4052,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -4046,7 +4062,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -4058,6 +4073,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -4068,8 +4084,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -4078,8 +4095,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -4089,7 +4105,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -4101,6 +4116,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -4111,8 +4127,9 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
- "me-south-1": "217643126080",
+ "il-central-1": "780543022126",
"me-central-1": "914824155844",
+ "me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -4121,8 +4138,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -4132,7 +4148,6 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
@@ -4144,6 +4159,7 @@
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
@@ -4154,6 +4170,8 @@
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -4163,8 +4181,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -4174,25 +4191,30 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -4202,8 +4224,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -4213,30 +4234,38 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -4246,25 +4275,30 @@
],
"registries": {
"af-south-1": "626614931356",
- "il-central-1": "780543022126",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
"ap-southeast-4": "457447274322",
"ca-central-1": "763104351884",
+ "ca-west-1": "204538143572",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
+ "il-central-1": "780543022126",
+ "me-central-1": "914824155844",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
@@ -4274,8 +4308,7 @@
"us-iso-east-1": "886529160074",
"us-isob-east-1": "094389454867",
"us-west-1": "763104351884",
- "us-west-2": "763104351884",
- "ca-west-1": "204538143572"
+ "us-west-2": "763104351884"
},
"repository": "tensorflow-training"
},
@@ -4324,4 +4357,4 @@
}
}
}
-}
+}
\ No newline at end of file
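
Note on the tensorflow.json changes above: the bulk of the hunks reorder each version's region map alphabetically (moving il-central-1, me-central-1, me-south-1, and ca-west-1 into sorted position), drop the long-retired 1.4.x–1.9.x "sagemaker-tensorflow" entries, and extend py_versions for several 2.0.x/2.1.x releases. The sketch below shows, under stated assumptions, how such a region-to-account map resolves to an ECR URI; the config path, the "training"/"versions" nesting, and the tag format are illustrative, not the SDK's actual loader.

import json

# Mirrors the ECR_URI_TEMPLATE constant used in image_uris.py below.
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"

def resolve_training_image(config_path: str, version: str, region: str, tag: str) -> str:
    """Hedged sketch: map (version, region) to an image URI via the JSON config."""
    with open(config_path) as f:
        config = json.load(f)
    # The "training" -> "versions" nesting is an assumption about the file layout.
    version_cfg = config["training"]["versions"][version]
    registry = version_cfg["registries"][region]  # e.g. "763104351884" or "204538143572"
    hostname = f"ecr.{region}.amazonaws.com"      # simplified; real code handles cn-/gov partitions
    uri = ECR_URI_TEMPLATE.format(
        registry=registry, hostname=hostname, repository=version_cfg["repository"]
    )
    return f"{uri}:{tag}"

# e.g. resolve_training_image("tensorflow.json", "1.15.5", "ca-west-1", "1.15.5-cpu-py37")
# -> "204538143572.dkr.ecr.ca-west-1.amazonaws.com/tensorflow-training:1.15.5-cpu-py37"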
diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py
index 143ecc9bdb..743f6b1f99 100644
--- a/src/sagemaker/image_uris.py
+++ b/src/sagemaker/image_uris.py
@@ -37,6 +37,8 @@
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
HUGGING_FACE_FRAMEWORK = "huggingface"
HUGGING_FACE_LLM_FRAMEWORK = "huggingface-llm"
+HUGGING_FACE_TEI_GPU_FRAMEWORK = "huggingface-tei"
+HUGGING_FACE_TEI_CPU_FRAMEWORK = "huggingface-tei-cpu"
HUGGING_FACE_LLM_NEURONX_FRAMEWORK = "huggingface-llm-neuronx"
XGBOOST_FRAMEWORK = "xgboost"
SKLEARN_FRAMEWORK = "sklearn"
@@ -68,6 +70,7 @@ def retrieve(
inference_tool=None,
serverless_inference_config=None,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name=None,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.
@@ -121,6 +124,7 @@ def retrieve(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The ECR URI for the corresponding SageMaker Docker image.
@@ -160,6 +164,7 @@ def retrieve(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):
@@ -477,6 +482,8 @@ def _validate_version_and_set_if_needed(version, config, framework):
if version is None and framework in [
DATA_WRANGLER_FRAMEWORK,
HUGGING_FACE_LLM_FRAMEWORK,
+ HUGGING_FACE_TEI_GPU_FRAMEWORK,
+ HUGGING_FACE_TEI_CPU_FRAMEWORK,
HUGGING_FACE_LLM_NEURONX_FRAMEWORK,
STABILITYAI_FRAMEWORK,
]:
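
With the two TEI framework aliases registered above and exempted from the mandatory-version check, image_uris.retrieve can serve Hugging Face Text Embeddings Inference lookups directly (this is the "cover tei with image_uris.retrieve API" change from the changelog). A hedged usage sketch follows; the region and version tag are illustrative assumptions, and depending on the config an image_scope or instance_type may also be required.

from sagemaker import image_uris

# GPU-backed TEI container (HUGGING_FACE_TEI_GPU_FRAMEWORK).
tei_gpu_uri = image_uris.retrieve(
    framework="huggingface-tei",
    region="us-east-1",
    version="1.2.3",  # assumed tag; per the change above, version may also be omitted
)

# CPU variant (HUGGING_FACE_TEI_CPU_FRAMEWORK); omitting version exercises
# the new exemption in _validate_version_and_set_if_needed.
tei_cpu_uri = image_uris.retrieve(
    framework="huggingface-tei-cpu",
    region="us-east-1",
)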
diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py
index 48aaab0ac8..c4af4b2036 100644
--- a/src/sagemaker/instance_types.py
+++ b/src/sagemaker/instance_types.py
@@ -36,6 +36,7 @@ def retrieve_default(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> str:
"""Retrieves the default instance type for the model matching the given arguments.
@@ -64,6 +65,7 @@ def retrieve_default(
Optionally supply this to get an inference instance type conditioned
on the training instance, to ensure compatibility of training artifact to inference
instance. (Default: None).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default instance type to use for the model.
@@ -88,6 +90,7 @@ def retrieve_default(
sagemaker_session=sagemaker_session,
training_instance_type=training_instance_type,
model_type=model_type,
+ config_name=config_name,
)
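
Since retrieve_default simply forwards the new keyword to _retrieve_default_instance_type, callers can now pin the lookup to a specific JumpStart model config rather than only to a model. An illustrative call; the model id and config name are assumptions:

from sagemaker import instance_types

default_instance = instance_types.retrieve_default(
    model_id="huggingface-llm-mistral-7b",  # hypothetical JumpStart model id
    model_version="*",
    scope="inference",
    config_name="tgi",  # hypothetical config defined in the model's spec
)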
diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py
index c28c27ed4e..fcb3ce3bf2 100644
--- a/src/sagemaker/jumpstart/artifacts/environment_variables.py
+++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py
@@ -39,6 +39,7 @@ def _retrieve_default_environment_variables(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
+ config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the inference environment variables for the model matching the given arguments.
@@ -68,6 +69,7 @@ def _retrieve_default_environment_variables(
environment variables specific for the instance type.
script (JumpStartScriptScope): The JumpStart script for which to retrieve
environment variables.
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the inference environment variables to use for the model.
"""
@@ -84,6 +86,7 @@ def _retrieve_default_environment_variables(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
default_environment_variables: Dict[str, str] = {}
@@ -121,6 +124,7 @@ def _retrieve_default_environment_variables(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
instance_type=instance_type,
+ config_name=config_name,
)
)
@@ -167,6 +171,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
+ config_name: Optional[str] = None,
) -> Optional[str]:
"""Retrieves the gated model env var URI matching the given arguments.
@@ -190,6 +195,7 @@ def _retrieve_gated_model_uri_env_var_value(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
environment variables specific for the instance type.
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
@@ -211,6 +217,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
s3_key: Optional[str] = (
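
The same one-line threading — accept config_name with a None default, forward it to verify_model_region_and_return_specs — repeats across the hyperparameters, image_uris, incremental_training, instance_types, kwargs, metric_definitions, model_packages, and model_uris artifact modules that follow. A condensed, runnable sketch of the pattern; the helper names here are placeholders, not the SDK's real spec-resolution call:

from typing import Optional

def _resolve_specs(model_id: str, version: str, config_name: Optional[str]) -> dict:
    """Stand-in for verify_model_region_and_return_specs (names are placeholders)."""
    base = {"model_id": model_id, "version": version}
    if config_name is not None:
        base["applied_config"] = config_name  # per-config overrides merged here in the SDK
    return base

def _retrieve_some_artifact(model_id: str, model_version: str,
                            config_name: Optional[str] = None):
    # The new keyword defaults to None, so existing callers are unaffected;
    # when provided, it is forwarded verbatim down to spec resolution.
    return _resolve_specs(model_id=model_id, version=model_version,
                          config_name=config_name)

print(_retrieve_some_artifact("example-model", "*", config_name="gpu-inference"))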
diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py
index d19530ecfb..67db7d260f 100644
--- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py
+++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py
@@ -36,6 +36,7 @@ def _retrieve_default_hyperparameters(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
+ config_name: Optional[str] = None,
):
"""Retrieves the training hyperparameters for the model matching the given arguments.
@@ -66,6 +67,7 @@ def _retrieve_default_hyperparameters(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get hyperparameters
specific for the instance type.
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the hyperparameters to use for the model.
"""
@@ -82,6 +84,7 @@ def _retrieve_default_hyperparameters(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
default_hyperparameters: Dict[str, str] = {}
diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py
index 9d19d5e069..72633320f5 100644
--- a/src/sagemaker/jumpstart/artifacts/image_uris.py
+++ b/src/sagemaker/jumpstart/artifacts/image_uris.py
@@ -46,6 +46,7 @@ def _retrieve_image_uri(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
):
"""Retrieves the container image URI for JumpStart models.
@@ -95,6 +96,7 @@ def _retrieve_image_uri(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the ECR URI for the corresponding SageMaker Docker image.
@@ -116,6 +118,7 @@ def _retrieve_image_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
if image_scope == JumpStartScriptScope.INFERENCE:
@@ -200,4 +203,5 @@ def _retrieve_image_uri(
distribution=distribution,
base_framework_version=base_framework_version_override or base_framework_version,
training_compiler_config=training_compiler_config,
+ config_name=config_name,
)
diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py
index 1b3c6f4b29..8bbe089354 100644
--- a/src/sagemaker/jumpstart/artifacts/incremental_training.py
+++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py
@@ -33,6 +33,7 @@ def _model_supports_incremental_training(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> bool:
"""Returns True if the model supports incremental training.
@@ -54,6 +55,7 @@ def _model_supports_incremental_training(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
bool: the support status for incremental training.
"""
@@ -70,6 +72,7 @@ def _model_supports_incremental_training(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
return model_specs.supports_incremental_training()
diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py
index e7c9c5911d..f4bf212c1c 100644
--- a/src/sagemaker/jumpstart/artifacts/instance_types.py
+++ b/src/sagemaker/jumpstart/artifacts/instance_types.py
@@ -40,6 +40,7 @@ def _retrieve_default_instance_type(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> str:
"""Retrieves the default instance type for the model.
@@ -68,6 +69,7 @@ def _retrieve_default_instance_type(
Optionally supply this to get an inference instance type conditioned
on the training instance, to ensure compatibility of training artifact to inference
instance. (Default: None).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the default instance type to use for the model or None.
@@ -89,6 +91,7 @@ def _retrieve_default_instance_type(
tolerate_deprecated_model=tolerate_deprecated_model,
model_type=model_type,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
if scope == JumpStartScriptScope.INFERENCE:
@@ -128,6 +131,7 @@ def _retrieve_instance_types(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
+ config_name: Optional[str] = None,
) -> List[str]:
"""Retrieves the supported instance types for the model.
@@ -156,6 +160,7 @@ def _retrieve_instance_types(
Optionally supply this to get an inference instance type conditioned
on the training instance, to ensure compatibility of training artifact to inference
instance. (Default: None).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
list: the supported instance types to use for the model or None.
@@ -176,6 +181,7 @@ def _retrieve_instance_types(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
if scope == JumpStartScriptScope.INFERENCE:
diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py
index 9cd152b0bb..ceb88d9b26 100644
--- a/src/sagemaker/jumpstart/artifacts/kwargs.py
+++ b/src/sagemaker/jumpstart/artifacts/kwargs.py
@@ -37,6 +37,7 @@ def _retrieve_model_init_kwargs(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> dict:
"""Retrieves kwargs for `Model`.
@@ -58,6 +59,7 @@ def _retrieve_model_init_kwargs(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the kwargs to use for the use case.
"""
@@ -75,6 +77,7 @@ def _retrieve_model_init_kwargs(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
kwargs = deepcopy(model_specs.model_kwargs)
@@ -94,6 +97,7 @@ def _retrieve_model_deploy_kwargs(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> dict:
"""Retrieves kwargs for `Model.deploy`.
@@ -117,6 +121,7 @@ def _retrieve_model_deploy_kwargs(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the kwargs to use for the use case.
@@ -135,6 +140,7 @@ def _retrieve_model_deploy_kwargs(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None:
@@ -151,6 +157,7 @@ def _retrieve_estimator_init_kwargs(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> dict:
"""Retrieves kwargs for `Estimator`.
@@ -174,6 +181,7 @@ def _retrieve_estimator_init_kwargs(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the kwargs to use for the use case.
"""
@@ -190,6 +198,7 @@ def _retrieve_estimator_init_kwargs(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
kwargs = deepcopy(model_specs.estimator_kwargs)
@@ -210,6 +219,7 @@ def _retrieve_estimator_fit_kwargs(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> dict:
"""Retrieves kwargs for `Estimator.fit`.
@@ -231,6 +241,7 @@ def _retrieve_estimator_fit_kwargs(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the kwargs to use for the use case.
@@ -248,6 +259,7 @@ def _retrieve_estimator_fit_kwargs(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
return model_specs.fit_kwargs
diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py
index 57f66155c7..f23b66aed4 100644
--- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py
+++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py
@@ -35,6 +35,7 @@ def _retrieve_default_training_metric_definitions(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
+ config_name: Optional[str] = None,
) -> Optional[List[Dict[str, str]]]:
"""Retrieves the default training metric definitions for the model.
@@ -58,6 +59,7 @@ def _retrieve_default_training_metric_definitions(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
metric definitions specific for the instance type.
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
list: the default training metric definitions to use for the model or None.
"""
@@ -74,6 +76,7 @@ def _retrieve_default_training_metric_definitions(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
default_metric_definitions = (
diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py
index aa22351771..67459519f3 100644
--- a/src/sagemaker/jumpstart/artifacts/model_packages.py
+++ b/src/sagemaker/jumpstart/artifacts/model_packages.py
@@ -37,6 +37,7 @@ def _retrieve_model_package_arn(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> Optional[str]:
"""Retrieves associated model pacakge arn for the model.
@@ -60,6 +61,7 @@ def _retrieve_model_package_arn(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the model package arn to use for the model or None.
@@ -78,6 +80,7 @@ def _retrieve_model_package_arn(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
if scope == JumpStartScriptScope.INFERENCE:
@@ -93,7 +96,10 @@ def _retrieve_model_package_arn(
if instance_specific_arn is not None:
return instance_specific_arn
- if model_specs.hosting_model_package_arns is None:
+ if (
+ model_specs.hosting_model_package_arns is None
+ or model_specs.hosting_model_package_arns == {}
+ ):
return None
regional_arn = model_specs.hosting_model_package_arns.get(region)
@@ -118,6 +124,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> Optional[str]:
"""Retrieves s3 artifact uri associated with model package.
@@ -141,6 +148,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the model package artifact uri to use for the model or None.
@@ -162,6 +170,7 @@ def _retrieve_model_package_model_artifact_s3_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
if model_specs.training_model_package_artifact_uris is None:
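
The widened guard above treats an empty hosting_model_package_arns dict the same as None. A minimal truthiness sketch of that rule; `if not arns:` would be the idiomatic one-line equivalent:

def has_package_arns(arns):
    # None and {} both mean "no hosting package ARN available".
    return bool(arns)

assert not has_package_arns(None)
assert not has_package_arns({})
assert has_package_arns({"us-west-2": "arn:aws:sagemaker:..."})
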
diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py
index 6bb2e576fc..00c6d8b9aa 100644
--- a/src/sagemaker/jumpstart/artifacts/model_uris.py
+++ b/src/sagemaker/jumpstart/artifacts/model_uris.py
@@ -95,6 +95,7 @@ def _retrieve_model_uri(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
):
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
@@ -120,6 +121,8 @@ def _retrieve_model_uri(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
+
Returns:
str: the model artifact S3 URI for the corresponding model.
@@ -141,6 +144,7 @@ def _retrieve_model_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
model_artifact_key: str
@@ -182,6 +186,7 @@ def _model_supports_training_model_uri(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> bool:
"""Returns True if the model supports training with model uri field.
@@ -203,6 +208,7 @@ def _model_supports_training_model_uri(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
bool: the support status for model uri with training.
"""
@@ -219,6 +225,7 @@ def _model_supports_training_model_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
return model_specs.use_training_model_artifact()
diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py
index 3359e32732..2f4a8bb0ac 100644
--- a/src/sagemaker/jumpstart/artifacts/payloads.py
+++ b/src/sagemaker/jumpstart/artifacts/payloads.py
@@ -37,6 +37,7 @@ def _retrieve_example_payloads(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> Optional[Dict[str, JumpStartSerializablePayload]]:
"""Returns example payloads.
@@ -58,6 +59,7 @@ def _retrieve_example_payloads(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases
to the serializable payload object.
@@ -76,6 +78,7 @@ def _retrieve_example_payloads(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
default_payloads = model_specs.default_payloads
diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py
index 4f6dfe1fe3..635f063e05 100644
--- a/src/sagemaker/jumpstart/artifacts/predictors.py
+++ b/src/sagemaker/jumpstart/artifacts/predictors.py
@@ -78,6 +78,7 @@ def _retrieve_default_deserializer(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> BaseDeserializer:
"""Retrieves the default deserializer for the model.
@@ -98,6 +99,7 @@ def _retrieve_default_deserializer(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
BaseDeserializer: the default deserializer to use for the model.
@@ -111,6 +113,7 @@ def _retrieve_default_deserializer(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type))
@@ -124,6 +127,7 @@ def _retrieve_default_serializer(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> BaseSerializer:
"""Retrieves the default serializer for the model.
@@ -144,6 +148,7 @@ def _retrieve_default_serializer(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
BaseSerializer: the default serializer to use for the model.
"""
@@ -156,6 +161,7 @@ def _retrieve_default_serializer(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type))
@@ -169,6 +175,7 @@ def _retrieve_deserializer_options(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> List[BaseDeserializer]:
"""Retrieves the supported deserializers for the model.
@@ -189,6 +196,7 @@ def _retrieve_deserializer_options(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
List[BaseDeserializer]: the supported deserializers to use for the model.
"""
@@ -201,6 +209,7 @@ def _retrieve_deserializer_options(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
seen_classes: Set[Type] = set()
@@ -227,6 +236,7 @@ def _retrieve_serializer_options(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> List[BaseSerializer]:
"""Retrieves the supported serializers for the model.
@@ -247,6 +257,7 @@ def _retrieve_serializer_options(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
List[BaseSerializer]: the supported serializers to use for the model.
"""
@@ -258,6 +269,7 @@ def _retrieve_serializer_options(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
seen_classes: Set[Type] = set()
@@ -285,6 +297,7 @@ def _retrieve_default_content_type(
tolerate_deprecated_model: bool = False,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> str:
"""Retrieves the default content type for the model.
@@ -305,6 +318,7 @@ def _retrieve_default_content_type(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the default content type to use for the model.
"""
@@ -322,6 +336,7 @@ def _retrieve_default_content_type(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
default_content_type = model_specs.predictor_specs.default_content_type
@@ -336,6 +351,7 @@ def _retrieve_default_accept_type(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> str:
"""Retrieves the default accept type for the model.
@@ -356,6 +372,7 @@ def _retrieve_default_accept_type(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the default accept type to use for the model.
"""
@@ -373,6 +390,7 @@ def _retrieve_default_accept_type(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
default_accept_type = model_specs.predictor_specs.default_accept_type
@@ -388,6 +406,7 @@ def _retrieve_supported_accept_types(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> List[str]:
"""Retrieves the supported accept types for the model.
@@ -408,6 +427,7 @@ def _retrieve_supported_accept_types(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
list: the supported accept types to use for the model.
"""
@@ -425,6 +445,7 @@ def _retrieve_supported_accept_types(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
supported_accept_types = model_specs.predictor_specs.supported_accept_types
@@ -440,6 +461,7 @@ def _retrieve_supported_content_types(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> List[str]:
"""Retrieves the supported content types for the model.
@@ -460,6 +482,7 @@ def _retrieve_supported_content_types(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
list: the supported content types to use for the model.
"""
@@ -477,6 +500,7 @@ def _retrieve_supported_content_types(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
supported_content_types = model_specs.predictor_specs.supported_content_types
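
Since the model factory below passes config_name into the public retrieve_default helpers, the same knob is usable directly; a hedged sketch with placeholder identifiers:

from sagemaker import deserializers, serializers

serializer = serializers.retrieve_default(
    model_id="huggingface-llm-example",  # placeholder model id
    model_version="*",
    config_name="tgi-inference",         # placeholder config name
)
deserializer = deserializers.retrieve_default(
    model_id="huggingface-llm-example",
    model_version="*",
    config_name="tgi-inference",
)
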
diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py
index cffd46d043..b4fdac770b 100644
--- a/src/sagemaker/jumpstart/artifacts/resource_names.py
+++ b/src/sagemaker/jumpstart/artifacts/resource_names.py
@@ -35,6 +35,8 @@ def _retrieve_resource_name_base(
tolerate_deprecated_model: bool = False,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
+ config_name: Optional[str] = None,
) -> str:
"""Returns default resource name.
@@ -56,6 +58,7 @@ def _retrieve_resource_name_base(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ scope (JumpStartScriptScope): The script scope for which to retrieve specs.
+ (Default: JumpStartScriptScope.INFERENCE).
+ config_name (Optional[str]): Name of the JumpStart Model config. (Default: None).
Returns:
str: the default resource name.
"""
@@ -67,12 +70,13 @@ def _retrieve_resource_name_base(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
- scope=JumpStartScriptScope.INFERENCE,
+ scope=scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
model_type=model_type,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
return model_specs.resource_name_base
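
A sketch of the new scope parameter, which lets training flows derive the default resource name under the TRAINING scope rather than the previously hard-coded INFERENCE scope; this calls a private helper and the identifiers are placeholders:

from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
from sagemaker.jumpstart.enums import JumpStartScriptScope

name_base = _retrieve_resource_name_base(
    model_id="huggingface-llm-example",  # placeholder model id
    model_version="*",
    scope=JumpStartScriptScope.TRAINING,
    config_name="gpu-training",          # placeholder training config
)
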
diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py
index 369acac85f..49126da336 100644
--- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py
+++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py
@@ -54,6 +54,7 @@ def _retrieve_default_resources(
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
+ config_name: Optional[str] = None,
) -> ResourceRequirements:
"""Retrieves the default resource requirements for the model.
@@ -79,6 +80,7 @@ def _retrieve_default_resources(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
host requirements specific for the instance type.
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
ResourceRequirements: The default resource requirements to use for the model, or None.
@@ -102,6 +104,7 @@ def _retrieve_default_resources(
tolerate_deprecated_model=tolerate_deprecated_model,
model_type=model_type,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
if scope == JumpStartScriptScope.INFERENCE:
diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py
index f69732d2e0..97313ec626 100644
--- a/src/sagemaker/jumpstart/artifacts/script_uris.py
+++ b/src/sagemaker/jumpstart/artifacts/script_uris.py
@@ -37,6 +37,7 @@ def _retrieve_script_uri(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
):
"""Retrieves the script S3 URI associated with the model matching the given arguments.
@@ -62,6 +63,7 @@ def _retrieve_script_uri(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the model script URI for the corresponding model.
@@ -83,6 +85,7 @@ def _retrieve_script_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
if script_scope == JumpStartScriptScope.INFERENCE:
@@ -108,6 +111,7 @@ def _model_supports_inference_script_uri(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> bool:
"""Returns True if the model supports inference with script uri field.
@@ -145,6 +149,7 @@ def _model_supports_inference_script_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
return model_specs.use_inference_script_uri()
diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py
index ca49fd41a3..9666ce828f 100644
--- a/src/sagemaker/jumpstart/enums.py
+++ b/src/sagemaker/jumpstart/enums.py
@@ -93,6 +93,9 @@ class JumpStartTag(str, Enum):
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
+ INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name"
+ TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name"
+
class SerializerType(str, Enum):
"""Enum class for serializers associated with JumpStart models."""
diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py
index 88927ae931..5f7e0ed82c 100644
--- a/src/sagemaker/jumpstart/estimator.py
+++ b/src/sagemaker/jumpstart/estimator.py
@@ -33,8 +33,10 @@
from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs
from sagemaker.jumpstart.factory.model import get_default_predictor
-from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job
+from sagemaker.jumpstart.session_utils import get_model_info_from_training_job
+from sagemaker.jumpstart.types import JumpStartMetadataConfig
from sagemaker.jumpstart.utils import (
+ get_jumpstart_configs,
validate_model_id_and_get_type,
resolve_model_sagemaker_config_field,
verify_model_region_and_return_specs,
@@ -109,6 +111,8 @@ def __init__(
container_arguments: Optional[List[str]] = None,
disable_output_compression: Optional[bool] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
+ config_name: Optional[str] = None,
+ enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
):
"""Initializes a ``JumpStartEstimator``.
@@ -500,6 +504,10 @@ def __init__(
to Amazon S3 without compression after training finishes.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job.
+ config_name (Optional[str]):
+ Name of the training configuration to apply to the Estimator. (Default: None).
+ enable_session_tag_chaining (bool or PipelineVariable): Optional.
+ Specifies whether SessionTagChaining is enabled for the training job.
Raises:
ValueError: If the model ID is not recognized by JumpStart.
@@ -578,6 +586,8 @@ def _validate_model_id_and_get_type_hook():
disable_output_compression=disable_output_compression,
enable_infra_check=enable_infra_check,
enable_remote_debug=enable_remote_debug,
+ config_name=config_name,
+ enable_session_tag_chaining=enable_session_tag_chaining,
)
self.model_id = estimator_init_kwargs.model_id
@@ -591,6 +601,8 @@ def _validate_model_id_and_get_type_hook():
self.role = estimator_init_kwargs.role
self.sagemaker_session = estimator_init_kwargs.sagemaker_session
self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation
+ self.config_name = estimator_init_kwargs.config_name
+ self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False)
super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict())
@@ -665,6 +677,7 @@ def fit(
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
tolerate_deprecated_model=self.tolerate_deprecated_model,
sagemaker_session=self.sagemaker_session,
+ config_name=self.config_name,
)
return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())
@@ -677,6 +690,7 @@ def attach(
model_version: Optional[str] = None,
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_channel_name: str = "model",
+ config_name: Optional[str] = None,
) -> "JumpStartEstimator":
"""Attach to an existing training job.
@@ -712,6 +726,8 @@ def attach(
model data will be downloaded (default: 'model'). If no channel
with the same name exists in the training job, this option will
be ignored.
+ config_name (str): Optional. Name of the training configuration to use
+ when attaching to the training job. (Default: None).
Returns:
Instance of the calling ``JumpStartEstimator`` Class with the attached
@@ -721,16 +737,23 @@ def attach(
ValueError: if the model ID or version cannot be inferred from the training job.
"""
-
if model_id is None:
-
- model_id, model_version = get_model_id_version_from_training_job(
+ model_id, model_version, _, config_name = get_model_info_from_training_job(
training_job_name=training_job_name, sagemaker_session=sagemaker_session
)
model_version = model_version or "*"
- additional_kwargs = {"model_id": model_id, "model_version": model_version}
+ additional_kwargs = {
+ "model_id": model_id,
+ "model_version": model_version,
+ "tolerate_vulnerable_model": True, # model is already trained
+ "tolerate_deprecated_model": True, # model is already trained
+ }
+
+ if config_name:
+ additional_kwargs.update({"config_name": config_name})
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
@@ -740,6 +763,7 @@ def attach(
tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated
tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
# eula was already accepted if the model was successfully trained
@@ -789,6 +813,7 @@ def deploy(
dependencies: Optional[List[str]] = None,
git_config: Optional[Dict[str, str]] = None,
use_compiled_model: bool = False,
+ inference_config_name: Optional[str] = None,
) -> PredictorBase:
"""Creates endpoint from training job.
@@ -1024,6 +1049,8 @@ def deploy(
(Default: None).
use_compiled_model (bool): Flag to select whether to use compiled
(optimized) model. (Default: False).
+ inference_config_name (Optional[str]): Name of the inference configuration to
+ be used for the model. (Default: None).
"""
self.orig_predictor_cls = predictor_cls
@@ -1076,6 +1103,8 @@ def deploy(
git_config=git_config,
use_compiled_model=use_compiled_model,
training_instance_type=self.instance_type,
+ training_config_name=self.config_name,
+ inference_config_name=inference_config_name,
)
predictor = super(JumpStartEstimator, self).deploy(
@@ -1092,11 +1121,43 @@ def deploy(
tolerate_deprecated_model=self.tolerate_deprecated_model,
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
sagemaker_session=self.sagemaker_session,
+ config_name=estimator_deploy_kwargs.config_name,
)
# If a predictor class was passed, do not mutate predictor
return predictor
+ def list_training_configs(self) -> List[JumpStartMetadataConfig]:
+ """Returns a list of configs associated with the estimator.
+
+ Raises:
+ ValueError: If the combination of arguments specified is not supported.
+ VulnerableJumpStartModelError: If any of the dependencies required by the script have
+ known security vulnerabilities.
+ DeprecatedJumpStartModelError: If the version of the model is deprecated.
+ """
+ configs_dict = get_jumpstart_configs(
+ model_id=self.model_id,
+ model_version=self.model_version,
+ model_type=self.model_type,
+ region=self.region,
+ scope=JumpStartScriptScope.TRAINING,
+ sagemaker_session=self.sagemaker_session,
+ )
+ return list(configs_dict.values())
+
+ def set_training_config(self, config_name: str) -> None:
+ """Sets the config to apply to the model.
+
+ Args:
+ config_name (str): The name of the config.
+ """
+ self.__init__(
+ model_id=self.model_id,
+ model_version=self.model_version,
+ config_name=config_name,
+ )
+
def __str__(self) -> str:
"""Overriding str(*) method to make more human-readable."""
return stringify_object(self)
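
A minimal end-to-end sketch of the new training-config surface on JumpStartEstimator; the model id and config names are hypothetical placeholders, and note that set_training_config re-runs __init__ with the new name:

from sagemaker.jumpstart.estimator import JumpStartEstimator

estimator = JumpStartEstimator(
    model_id="huggingface-llm-example",  # hypothetical model id
    config_name="gpu-training",          # hypothetical training config
)

# Configs come from JumpStart metadata; each entry is a JumpStartMetadataConfig.
for config in estimator.list_training_configs():
    print(config.config_name)

# Re-resolves estimator kwargs under the newly selected training config.
estimator.set_training_config(config_name="gpu-training-budget")
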
diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py
index 875ec9d003..e171dcd99c 100644
--- a/src/sagemaker/jumpstart/factory/estimator.py
+++ b/src/sagemaker/jumpstart/factory/estimator.py
@@ -61,7 +61,7 @@
JumpStartModelInitKwargs,
)
from sagemaker.jumpstart.utils import (
- add_jumpstart_model_id_version_tags,
+ add_jumpstart_model_info_tags,
get_eula_message,
update_dict_if_key_not_present,
resolve_estimator_sagemaker_config_field,
@@ -130,6 +130,8 @@ def get_init_kwargs(
disable_output_compression: Optional[bool] = None,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
+ config_name: Optional[str] = None,
+ enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
) -> JumpStartEstimatorInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""
@@ -188,6 +190,8 @@ def get_init_kwargs(
disable_output_compression=disable_output_compression,
enable_infra_check=enable_infra_check,
enable_remote_debug=enable_remote_debug,
+ config_name=config_name,
+ enable_session_tag_chaining=enable_session_tag_chaining,
)
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)
@@ -205,6 +209,7 @@ def get_init_kwargs(
estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs)
estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs)
+ estimator_init_kwargs = _add_config_name_to_kwargs(estimator_init_kwargs)
return estimator_init_kwargs
@@ -221,6 +226,7 @@ def get_fit_kwargs(
tolerate_vulnerable_model: Optional[bool] = None,
tolerate_deprecated_model: Optional[bool] = None,
sagemaker_session: Optional[Session] = None,
+ config_name: Optional[str] = None,
) -> JumpStartEstimatorFitKwargs:
"""Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object."""
@@ -236,6 +242,7 @@ def get_fit_kwargs(
tolerate_deprecated_model=tolerate_deprecated_model,
tolerate_vulnerable_model=tolerate_vulnerable_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
estimator_fit_kwargs = _add_model_version_to_kwargs(estimator_fit_kwargs)
@@ -287,6 +294,8 @@ def get_deploy_kwargs(
use_compiled_model: Optional[bool] = None,
model_name: Optional[str] = None,
training_instance_type: Optional[str] = None,
+ training_config_name: Optional[str] = None,
+ inference_config_name: Optional[str] = None,
) -> JumpStartEstimatorDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object."""
@@ -314,6 +323,8 @@ def get_deploy_kwargs(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ training_config_name=training_config_name,
+ config_name=inference_config_name,
)
model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs(
@@ -342,6 +353,7 @@ def get_deploy_kwargs(
tolerate_deprecated_model=tolerate_deprecated_model,
training_instance_type=training_instance_type,
disable_instance_type_logging=True,
+ config_name=model_deploy_kwargs.config_name,
)
estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(
@@ -386,6 +398,7 @@ def get_deploy_kwargs(
tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model,
tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model,
use_compiled_model=use_compiled_model,
+ config_name=model_deploy_kwargs.config_name,
)
return estimator_deploy_kwargs
@@ -441,6 +454,7 @@ def _add_instance_type_and_count_to_kwargs(
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
kwargs.instance_count = kwargs.instance_count or 1
@@ -464,11 +478,16 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
).version
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
- kwargs.tags = add_jumpstart_model_id_version_tags(
- kwargs.tags, kwargs.model_id, full_model_version
+ kwargs.tags = add_jumpstart_model_info_tags(
+ kwargs.tags,
+ kwargs.model_id,
+ full_model_version,
+ config_name=kwargs.config_name,
+ scope=JumpStartScriptScope.TRAINING,
)
return kwargs
@@ -486,6 +505,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
return kwargs
@@ -511,6 +531,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
sagemaker_session=kwargs.sagemaker_session,
region=kwargs.region,
instance_type=kwargs.instance_type,
+ config_name=kwargs.config_name,
)
if (
@@ -523,6 +544,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
):
JUMPSTART_LOGGER.warning(
@@ -558,6 +580,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
region=kwargs.region,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
return kwargs
@@ -578,6 +601,7 @@ def _add_env_to_kwargs(
sagemaker_session=kwargs.sagemaker_session,
script=JumpStartScriptScope.TRAINING,
instance_type=kwargs.instance_type,
+ config_name=kwargs.config_name,
)
model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri(
@@ -588,6 +612,7 @@ def _add_env_to_kwargs(
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
if model_package_artifact_uri:
@@ -615,6 +640,7 @@ def _add_env_to_kwargs(
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
if model_specs.is_gated_model():
raise ValueError(
@@ -644,9 +670,11 @@ def _add_training_job_name_to_kwargs(
model_id=kwargs.model_id,
model_version=kwargs.model_version,
region=kwargs.region,
+ scope=JumpStartScriptScope.TRAINING,
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
kwargs.job_name = kwargs.job_name or (
@@ -673,6 +701,7 @@ def _add_hyperparameters_to_kwargs(
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
instance_type=kwargs.instance_type,
+ config_name=kwargs.config_name,
)
for key, value in default_hyperparameters.items():
@@ -706,6 +735,7 @@ def _add_metric_definitions_to_kwargs(
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
instance_type=kwargs.instance_type,
+ config_name=kwargs.config_name,
)
or []
)
@@ -735,6 +765,7 @@ def _add_estimator_extra_kwargs(
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
for key, value in estimator_kwargs_to_add.items():
@@ -759,6 +790,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
for key, value in fit_kwargs_to_add.items():
@@ -766,3 +798,27 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim
setattr(kwargs, key, value)
return kwargs
+
+
+def _add_config_name_to_kwargs(
+ kwargs: JumpStartEstimatorInitKwargs,
+) -> JumpStartEstimatorInitKwargs:
+ """Sets tags in kwargs based on default or override, returns full kwargs."""
+
+ specs = verify_model_region_and_return_specs(
+ model_id=kwargs.model_id,
+ version=kwargs.model_version,
+ scope=JumpStartScriptScope.TRAINING,
+ region=kwargs.region,
+ tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
+ tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
+ sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
+ )
+
+ if specs.training_configs and specs.training_configs.get_top_config_from_ranking():
+ kwargs.config_name = (
+ kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name
+ )
+
+ return kwargs
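
The precedence rule implemented by _add_config_name_to_kwargs, restated as a self-contained sketch: an explicit config_name always wins, otherwise the top-ranked training config (when the model defines any) becomes the default:

from typing import Optional

def resolve_config_name(explicit: Optional[str], top_ranked: Optional[str]) -> Optional[str]:
    # Explicit override first, then the metadata-ranked default, else None.
    return explicit or top_ranked

assert resolve_config_name("my-config", "ranked-default") == "my-config"
assert resolve_config_name(None, "ranked-default") == "ranked-default"
assert resolve_config_name(None, None) is None
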
diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py
index 28746990e3..0bae5955e2 100644
--- a/src/sagemaker/jumpstart/factory/model.py
+++ b/src/sagemaker/jumpstart/factory/model.py
@@ -42,9 +42,10 @@
JumpStartModelDeployKwargs,
JumpStartModelInitKwargs,
JumpStartModelRegisterKwargs,
+ JumpStartModelSpecs,
)
from sagemaker.jumpstart.utils import (
- add_jumpstart_model_id_version_tags,
+ add_jumpstart_model_info_tags,
update_dict_if_key_not_present,
resolve_model_sagemaker_config_field,
verify_model_region_and_return_specs,
@@ -72,6 +73,7 @@ def get_default_predictor(
tolerate_deprecated_model: bool,
sagemaker_session: Session,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> Predictor:
"""Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one.
@@ -94,6 +96,7 @@ def get_default_predictor(
tolerate_vulnerable_model=tolerate_vulnerable_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
predictor.deserializer = deserializers.retrieve_default(
model_id=model_id,
@@ -103,6 +106,7 @@ def get_default_predictor(
tolerate_vulnerable_model=tolerate_vulnerable_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
predictor.accept = accept_types.retrieve_default(
model_id=model_id,
@@ -112,6 +116,7 @@ def get_default_predictor(
tolerate_vulnerable_model=tolerate_vulnerable_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
predictor.content_type = content_types.retrieve_default(
model_id=model_id,
@@ -121,6 +126,7 @@ def get_default_predictor(
tolerate_vulnerable_model=tolerate_vulnerable_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
return predictor
@@ -184,7 +190,6 @@ def _add_instance_type_to_kwargs(
"""Sets instance type based on default or override, returns full kwargs."""
orig_instance_type = kwargs.instance_type
-
kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default(
region=kwargs.region,
model_id=kwargs.model_id,
@@ -195,6 +200,7 @@ def _add_instance_type_to_kwargs(
sagemaker_session=kwargs.sagemaker_session,
training_instance_type=kwargs.training_instance_type,
model_type=kwargs.model_type,
+ config_name=kwargs.config_name,
)
if not disable_instance_type_logging and orig_instance_type is None:
@@ -226,6 +232,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
return kwargs
@@ -247,6 +254,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
instance_type=kwargs.instance_type,
+ config_name=kwargs.config_name,
)
if isinstance(model_data, str) and model_data.startswith("s3://") and model_data.endswith("/"):
@@ -287,6 +295,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
):
source_dir = source_dir or script_uris.retrieve(
script_scope=JumpStartScriptScope.INFERENCE,
@@ -296,6 +305,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
)
kwargs.source_dir = source_dir
@@ -319,6 +329,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
+ config_name=kwargs.config_name,
):
entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME
@@ -350,6 +361,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw
sagemaker_session=kwargs.sagemaker_session,
script=JumpStartScriptScope.INFERENCE,
instance_type=kwargs.instance_type,
+ config_name=kwargs.config_name,
)
for key, value in extra_env_vars.items():
@@ -380,6 +392,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
+ config_name=kwargs.config_name,
)
kwargs.model_package_arn = model_package_arn
@@ -397,6 +410,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
+ config_name=kwargs.config_name,
)
for key, value in model_kwargs_to_add.items():
@@ -433,6 +447,7 @@ def _add_endpoint_name_to_kwargs(
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
+ config_name=kwargs.config_name,
)
kwargs.endpoint_name = kwargs.endpoint_name or (
@@ -455,6 +470,7 @@ def _add_model_name_to_kwargs(
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
+ config_name=kwargs.config_name,
)
kwargs.name = kwargs.name or (
@@ -476,11 +492,17 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]:
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
+ config_name=kwargs.config_name,
).version
if kwargs.sagemaker_session.settings.include_jumpstart_tags:
- kwargs.tags = add_jumpstart_model_id_version_tags(
- kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type
+ kwargs.tags = add_jumpstart_model_info_tags(
+ kwargs.tags,
+ kwargs.model_id,
+ full_model_version,
+ kwargs.model_type,
+ config_name=kwargs.config_name,
+ scope=JumpStartScriptScope.INFERENCE,
)
return kwargs
@@ -498,6 +520,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any]
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
+ config_name=kwargs.config_name,
)
for key, value in deploy_kwargs_to_add.items():
@@ -520,8 +543,106 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
sagemaker_session=kwargs.sagemaker_session,
model_type=kwargs.model_type,
instance_type=kwargs.instance_type,
+ config_name=kwargs.config_name,
+ )
+
+ return kwargs
+
+
+def _select_inference_config_from_training_config(
+ specs: JumpStartModelSpecs, training_config_name: str
+) -> Optional[str]:
+ """Selects the inference config from the training config.
+
+ Args:
+ specs (JumpStartModelSpecs): The specs for the model.
+ training_config_name (str): The name of the training config.
+
+ Returns:
+ str: The name of the inference config.
+ """
+ if specs.training_configs:
+ resolved_training_config = specs.training_configs.configs.get(training_config_name)
+ if resolved_training_config:
+ return resolved_training_config.default_inference_config
+
+ return None
+
+
+def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
+ """Sets default config name to the kwargs. Returns full kwargs.
+
+ Raises:
+ ValueError: If the instance_type is not supported with the current config.
+ """
+
+ specs = verify_model_region_and_return_specs(
+ model_id=kwargs.model_id,
+ version=kwargs.model_version,
+ scope=JumpStartScriptScope.INFERENCE,
+ region=kwargs.region,
+ tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
+ tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
+ sagemaker_session=kwargs.sagemaker_session,
+ model_type=kwargs.model_type,
+ config_name=kwargs.config_name,
+ )
+ if specs.inference_configs:
+ default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
+ kwargs.config_name = kwargs.config_name or default_config_name
+
+ if not kwargs.config_name:
+ return kwargs
+
+ if kwargs.config_name not in set(specs.inference_configs.configs.keys()):
+ raise ValueError(
+ f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}."
+ )
+
+ resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config
+ supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
+ if kwargs.instance_type not in supported_instance_types:
+ raise ValueError(
+ f"Instance type {kwargs.instance_type} "
+ f"is not supported for config {kwargs.config_name}."
+ )
+
+ return kwargs
+
+
+def _add_config_name_to_deploy_kwargs(
+ kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
+) -> JumpStartModelInitKwargs:
+ """Sets default config name to the kwargs. Returns full kwargs.
+
+ If a training_config_name is passed, then choose the inference config
+ based on the supported inference configs in that training config.
+
+ Raises:
+ ValueError: If the instance_type is not supported with the current config.
+ """
+
+ specs = verify_model_region_and_return_specs(
+ model_id=kwargs.model_id,
+ version=kwargs.model_version,
+ scope=JumpStartScriptScope.INFERENCE,
+ region=kwargs.region,
+ tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
+ tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
+ sagemaker_session=kwargs.sagemaker_session,
+ model_type=kwargs.model_type,
+ config_name=kwargs.config_name,
)
+ if training_config_name:
+ kwargs.config_name = _select_inference_config_from_training_config(
+ specs=specs, training_config_name=training_config_name
+ )
+
+ if specs.inference_configs:
+ default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
+ kwargs.config_name = kwargs.config_name or default_config_name
+
return kwargs
@@ -555,6 +676,9 @@ def get_deploy_kwargs(
resources: Optional[ResourceRequirements] = None,
managed_instance_scaling: Optional[str] = None,
endpoint_type: Optional[EndpointType] = None,
+ training_config_name: Optional[str] = None,
+ config_name: Optional[str] = None,
+ routing_config: Optional[Dict[str, Any]] = None,
) -> JumpStartModelDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""
@@ -586,6 +710,8 @@ def get_deploy_kwargs(
accept_eula=accept_eula,
endpoint_logging=endpoint_logging,
resources=resources,
+ config_name=config_name,
+ routing_config=routing_config,
)
deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs)
@@ -594,6 +720,10 @@ def get_deploy_kwargs(
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
+ deploy_kwargs = _add_config_name_to_deploy_kwargs(
+ kwargs=deploy_kwargs, training_config_name=training_config_name
+ )
+
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
deploy_kwargs.initial_instance_count = initial_instance_count or 1
@@ -639,6 +769,7 @@ def get_register_kwargs(
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
+ config_name: Optional[str] = None,
) -> JumpStartModelRegisterKwargs:
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""
@@ -681,6 +812,7 @@ def get_register_kwargs(
sagemaker_session=sagemaker_session,
tolerate_deprecated_model=tolerate_deprecated_model,
tolerate_vulnerable_model=tolerate_vulnerable_model,
+ config_name=config_name,
)
register_kwargs.content_types = (
@@ -723,6 +855,7 @@ def get_init_kwargs(
training_instance_type: Optional[str] = None,
disable_instance_type_logging: bool = False,
resources: Optional[ResourceRequirements] = None,
+ config_name: Optional[str] = None,
) -> JumpStartModelInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
@@ -754,6 +887,7 @@ def get_init_kwargs(
model_package_arn=model_package_arn,
training_instance_type=training_instance_type,
resources=resources,
+ config_name=config_name,
)
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
@@ -784,4 +918,6 @@ def get_init_kwargs(
model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
+ model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
+
return model_init_kwargs
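
A dict-based sketch of the training-to-inference handoff performed by _add_config_name_to_deploy_kwargs via _select_inference_config_from_training_config; the config names are illustrative stand-ins for JumpStartModelSpecs metadata:

training_configs = {
    "gpu-training": {"default_inference_config": "gpu-inference"},
}

def select_inference_config(training_config_name):
    # Mirror of the helper above: resolve the training config, then hand back
    # its default inference config, or None when nothing matches.
    config = training_configs.get(training_config_name)
    return config["default_inference_config"] if config else None

assert select_inference_config("gpu-training") == "gpu-inference"
assert select_inference_config("unknown-config") is None
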
diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py
index 4529bc11b9..f72a3140dc 100644
--- a/src/sagemaker/jumpstart/model.py
+++ b/src/sagemaker/jumpstart/model.py
@@ -14,7 +14,8 @@
from __future__ import absolute_import
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Any, Union
+import pandas as pd
from botocore.exceptions import ClientError
from sagemaker import payloads
@@ -36,10 +37,18 @@
get_init_kwargs,
get_register_kwargs,
)
-from sagemaker.jumpstart.types import JumpStartSerializablePayload
+from sagemaker.jumpstart.types import (
+ JumpStartSerializablePayload,
+ DeploymentConfigMetadata,
+)
from sagemaker.jumpstart.utils import (
validate_model_id_and_get_type,
verify_model_region_and_return_specs,
+ get_jumpstart_configs,
+ get_metrics_from_deployment_configs,
+ add_instance_rate_stats_to_benchmark_metrics,
+ deployment_config_response_data,
+ _deployment_config_lru_cache,
)
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -92,6 +101,7 @@ def __init__(
git_config: Optional[Dict[str, str]] = None,
model_package_arn: Optional[str] = None,
resources: Optional[ResourceRequirements] = None,
+ config_name: Optional[str] = None,
):
"""Initializes a ``JumpStartModel``.
@@ -277,6 +287,8 @@ def __init__(
for a model to be deployed to an endpoint.
Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
(Default: None).
+ config_name (Optional[str]): The name of the JumpStart config that can be
+ optionally applied to the model to override the corresponding fields.
+ (Default: None).
Raises:
ValueError: If the model ID is not recognized by JumpStart.
"""
@@ -326,6 +338,7 @@ def _validate_model_id_and_type():
git_config=git_config,
model_package_arn=model_package_arn,
resources=resources,
+ config_name=config_name,
)
self.orig_predictor_cls = predictor_cls
@@ -338,6 +351,7 @@ def _validate_model_id_and_type():
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
self.region = model_init_kwargs.region
self.sagemaker_session = model_init_kwargs.sagemaker_session
+ self.config_name = model_init_kwargs.config_name
if self.model_type == JumpStartModelType.PROPRIETARY:
self.log_subscription_warning()
@@ -345,6 +359,15 @@ def _validate_model_id_and_type():
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
self.model_package_arn = model_init_kwargs.model_package_arn
+ self.init_kwargs = model_init_kwargs.to_kwargs_dict(False)
+
+ self._metadata_configs = get_jumpstart_configs(
+ region=self.region,
+ model_id=self.model_id,
+ model_version=self.model_version,
+ sagemaker_session=self.sagemaker_session,
+ model_type=self.model_type,
+ )
def log_subscription_warning(self) -> None:
"""Log message prompting the customer to subscribe to the proprietary model."""
@@ -402,6 +425,70 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload:
sagemaker_session=self.sagemaker_session,
)
+ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
+ """Sets the deployment config to apply to the model.
+
+ Args:
+ config_name (str):
+ The name of the deployment config to apply to the model.
+ Call list_deployment_configs to see the list of config names.
+ instance_type (str):
+ The instance_type that the model will use after setting
+ the config.
+ """
+ self.__init__(
+ model_id=self.model_id,
+ model_version=self.model_version,
+ instance_type=instance_type,
+ config_name=config_name,
+ )
+
+ @property
+ def deployment_config(self) -> Optional[Dict[str, Any]]:
+ """The deployment config that will be applied to ``This`` model.
+
+ Returns:
+ Optional[Dict[str, Any]]: Deployment config.
+ """
+ if self.config_name is None:
+ return None
+ for config in self.list_deployment_configs():
+ if config.get("DeploymentConfigName") == self.config_name:
+ return config
+ return None
+
+ @property
+ def benchmark_metrics(self) -> pd.DataFrame:
+ """Benchmark Metrics for deployment configs.
+
+ Returns:
+ Benchmark Metrics: Pandas DataFrame object.
+ """
+ df = pd.DataFrame(self._get_deployment_configs_benchmarks_data())
+ blank_index = [""] * len(df)
+ df.index = blank_index
+ return df
+
+ def display_benchmark_metrics(self, **kwargs) -> None:
+ """Display deployment configs benchmark metrics."""
+ df = self.benchmark_metrics
+
+ instance_type = kwargs.get("instance_type")
+ if instance_type:
+ df = df[df["Instance Type"].str.contains(instance_type)]
+
+ print(df.to_markdown(index=False, floatfmt=".2f"))
+
+ def list_deployment_configs(self) -> List[Dict[str, Any]]:
+ """List deployment configs for ``This`` model.
+
+ Returns:
+ List[Dict[str, Any]]: A list of deployment configs.
+ """
+ return deployment_config_response_data(
+ self._get_deployment_configs(self.config_name, self.instance_type)
+ )
+
def _create_sagemaker_model(
self,
instance_type=None,
@@ -496,6 +583,7 @@ def deploy(
resources: Optional[ResourceRequirements] = None,
managed_instance_scaling: Optional[str] = None,
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
+ routing_config: Optional[Dict[str, Any]] = None,
) -> PredictorBase:
"""Creates endpoint by calling base ``Model`` class `deploy` method.
@@ -590,6 +678,8 @@ def deploy(
endpoint.
endpoint_type (EndpointType): The type of endpoint used to deploy models.
(Default: EndpointType.MODEL_BASED).
+ routing_config (Optional[Dict]): Settings that control how the endpoint routes
+ incoming traffic to the instances that the endpoint hosts.
Raises:
MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.
@@ -625,6 +715,8 @@ def deploy(
managed_instance_scaling=managed_instance_scaling,
endpoint_type=endpoint_type,
model_type=self.model_type,
+ config_name=self.config_name,
+ routing_config=routing_config,
)
if (
self.model_type == JumpStartModelType.PROPRIETARY
@@ -644,6 +736,7 @@ def deploy(
model_type=self.model_type,
scope=JumpStartScriptScope.INFERENCE,
sagemaker_session=self.sagemaker_session,
+ config_name=self.config_name,
).model_subscription_link
get_proprietary_model_subscription_error(e, subscription_link)
raise
@@ -659,6 +752,7 @@ def deploy(
tolerate_vulnerable_model=self.tolerate_vulnerable_model,
sagemaker_session=self.sagemaker_session,
model_type=self.model_type,
+ config_name=self.config_name,
)
# If a predictor class was passed, do not mutate predictor
@@ -769,6 +863,7 @@ def register(
data_input_configuration=data_input_configuration,
skip_model_validation=skip_model_validation,
source_uri=source_uri,
+ config_name=self.config_name,
)
model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())
@@ -786,6 +881,89 @@ def register_deploy_wrapper(*args, **kwargs):
return model_package
+ @_deployment_config_lru_cache
+ def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]:
+ """Deployment configs benchmark metrics.
+
+ Returns:
+ Dict[str, Any]: Deployment config benchmark data.
+ """
+ return get_metrics_from_deployment_configs(
+ self._get_deployment_configs(None, None),
+ )
+
+ @_deployment_config_lru_cache
+ def _get_deployment_configs(
+ self, selected_config_name: Optional[str], selected_instance_type: Optional[str]
+ ) -> List[DeploymentConfigMetadata]:
+ """Retrieve deployment configs metadata.
+
+ Args:
+ selected_config_name (Optional[str]): The name of the selected deployment config.
+ selected_instance_type (Optional[str]): The selected instance type.
+ """
+ deployment_configs = []
+ if not self._metadata_configs:
+ return deployment_configs
+
+ err = None
+ for config_name, metadata_config in self._metadata_configs.items():
+ if selected_config_name == config_name:
+ instance_type_to_use = selected_instance_type
+ else:
+ instance_type_to_use = metadata_config.resolved_config.get(
+ "default_inference_instance_type"
+ )
+
+ if metadata_config.benchmark_metrics:
+ err, metadata_config.benchmark_metrics = (
+ add_instance_rate_stats_to_benchmark_metrics(
+ self.region, metadata_config.benchmark_metrics
+ )
+ )
+
+ config_components = metadata_config.config_components.get(config_name)
+ image_uri = (
+ (
+ config_components.hosting_instance_type_variants.get("regional_aliases", {})
+ .get(self.region, {})
+ .get("alias_ecr_uri_1")
+ )
+ if config_components
+ else self.image_uri
+ )
+
+ init_kwargs = get_init_kwargs(
+ config_name=config_name,
+ model_id=self.model_id,
+ instance_type=instance_type_to_use,
+ sagemaker_session=self.sagemaker_session,
+ image_uri=image_uri,
+ region=self.region,
+ model_version=self.model_version,
+ )
+ deploy_kwargs = get_deploy_kwargs(
+ model_id=self.model_id,
+ instance_type=instance_type_to_use,
+ sagemaker_session=self.sagemaker_session,
+ region=self.region,
+ model_version=self.model_version,
+ )
+
+ deployment_config_metadata = DeploymentConfigMetadata(
+ config_name,
+ metadata_config,
+ init_kwargs,
+ deploy_kwargs,
+ )
+ deployment_configs.append(deployment_config_metadata)
+
+ if err and err["Code"] == "AccessDeniedException":
+ error_message = "Instance rate metrics will be omitted. Reason: %s"
+ JUMPSTART_LOGGER.warning(error_message, err["Message"])
+
+ return deployment_configs
+
def __str__(self) -> str:
"""Overriding str(*) method to make more human-readable."""
return stringify_object(self)
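The two cached helpers above are what the deployment-config benchmarking surface is built on. A minimal sketch of exercising them directly, assuming a JumpStart model that publishes inference configs; the model ID is illustrative, and these are private helpers, so the call pattern is an assumption rather than a documented API:

from sagemaker.jumpstart.model import JumpStartModel

# Illustrative model ID; any JumpStart model that ships inference configs works.
model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b")

# Build DeploymentConfigMetadata for every config, using each config's
# default inference instance type (both selection arguments are None).
for cfg in model._get_deployment_configs(None, None):
    print(cfg.deployment_config_name, cfg.deployment_args.instance_type)

# Tabular benchmark data, with instance rates merged in when the pricing
# API is accessible; results are memoized by @_deployment_config_lru_cache.
data = model._get_deployment_configs_benchmarks_data()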
diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py
index 732493ce3b..781548b42a 100644
--- a/src/sagemaker/jumpstart/notebook_utils.py
+++ b/src/sagemaker/jumpstart/notebook_utils.py
@@ -329,9 +329,12 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
return sorted(list(model_id_version_dict.keys()))
if not list_old_models:
- model_id_version_dict = {
- model_id: set([max(versions)]) for model_id, versions in model_id_version_dict.items()
- }
+ for model_id, versions in model_id_version_dict.items():
+ try:
+ model_id_version_dict.update({model_id: set([max(versions)])})
+ except TypeError:
+ versions = [str(v) for v in versions]
+ model_id_version_dict.update({model_id: set([max(versions)])})
model_id_version_set: Set[Tuple[str, str]] = set()
for model_id in model_id_version_dict:
@@ -532,6 +535,7 @@ def get_model_url(
model_version: str,
region: Optional[str] = None,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> str:
"""Retrieve web url describing pretrained model.
@@ -560,5 +564,6 @@ def get_model_url(
sagemaker_session=sagemaker_session,
scope=JumpStartScriptScope.INFERENCE,
model_type=model_type,
+ config_name=config_name,
)
return model_specs.url
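The try/except added above handles version sets that mix comparable and non-comparable types, where max() raises TypeError. A minimal reproduction of the failure mode and the string fallback; the values are made up, and note the fallback compares lexicographically, which suffices here:

from packaging.version import Version

versions = {Version("1.0.0"), "2.0.0"}  # mixed types in one version set
try:
    latest = max(versions)
except TypeError:
    # Version and str do not compare, so fall back to string comparison,
    # mirroring the fix in list_jumpstart_models above.
    latest = max(str(v) for v in versions)
print(latest)  # "2.0.0"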
diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py
index 595f801598..9c6716dc64 100644
--- a/src/sagemaker/jumpstart/payload_utils.py
+++ b/src/sagemaker/jumpstart/payload_utils.py
@@ -23,7 +23,7 @@
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
-from sagemaker.jumpstart.enums import MIMEType
+from sagemaker.jumpstart.enums import JumpStartModelType, MIMEType
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
@@ -61,6 +61,8 @@ def _construct_payload(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ alias: Optional[str] = None,
) -> Optional[JumpStartSerializablePayload]:
"""Returns example payload from prompt.
@@ -83,6 +85,8 @@ def _construct_payload(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ model_type (JumpStartModelType): The type of the model, either open weights or
+ proprietary. (Default: JumpStartModelType.OPEN_WEIGHTS).
+ alias (Optional[str]): Alias of the example payload to select; when None, the
+ first available payload is used. (Default: None).
Returns:
Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if
this feature is unavailable for the specified model.
@@ -94,11 +98,14 @@ def _construct_payload(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ model_type=model_type,
)
if payloads is None or len(payloads) == 0:
return None
- payload_to_use: JumpStartSerializablePayload = list(payloads.values())[0]
+ payload_to_use: JumpStartSerializablePayload = (
+ payloads[alias] if alias else list(payloads.values())[0]
+ )
prompt_key: Optional[str] = payload_to_use.prompt_key
if prompt_key is None:
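The new alias parameter lets a caller pin a specific example payload rather than taking the first entry in the payload dict. A hedged sketch, assuming the leading parameters (prompt, model ID, model version) that this hunk does not show; the alias value is purely illustrative:

from sagemaker.jumpstart.payload_utils import _construct_payload

payload = _construct_payload(
    prompt="What is the capital of France?",  # example prompt
    model_id="huggingface-llm-mistral-7b",    # illustrative model ID
    model_version="*",
    alias="example-alias",  # hypothetical alias; None keeps the old first-payload behavior
)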
diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py
index e511a052d1..0955ae9480 100644
--- a/src/sagemaker/jumpstart/session_utils.py
+++ b/src/sagemaker/jumpstart/session_utils.py
@@ -17,17 +17,17 @@
from typing import Optional, Tuple
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
-from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn
+from sagemaker.jumpstart.utils import get_jumpstart_model_info_from_resource_arn
from sagemaker.session import Session
from sagemaker.utils import aws_partition
-def get_model_id_version_from_endpoint(
+def get_model_info_from_endpoint(
endpoint_name: str,
inference_component_name: Optional[str] = None,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
-) -> Tuple[str, str, Optional[str]]:
- """Given an endpoint and optionally inference component names, return the model ID and version.
+) -> Tuple[str, str, Optional[str], Optional[str], Optional[str]]:
+ """Optionally inference component names, return the model ID, version and config name.
Infers the model ID and version based on the resource tags. Returns a tuple of the model ID
and version. A third string element is included in the tuple for any inferred inference
@@ -46,7 +46,9 @@ def get_model_id_version_from_endpoint(
(
model_id,
model_version,
- ) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
+ inference_config_name,
+ training_config_name,
+ ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301
inference_component_name, sagemaker_session
)
@@ -54,22 +56,35 @@ def get_model_id_version_from_endpoint(
(
model_id,
model_version,
+ inference_config_name,
+ training_config_name,
inference_component_name,
- ) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
+ ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301
endpoint_name, sagemaker_session
)
else:
- model_id, model_version = _get_model_id_version_from_model_based_endpoint(
+ (
+ model_id,
+ model_version,
+ inference_config_name,
+ training_config_name,
+ ) = _get_model_info_from_model_based_endpoint(
endpoint_name, inference_component_name, sagemaker_session
)
- return model_id, model_version, inference_component_name
+ return (
+ model_id,
+ model_version,
+ inference_component_name,
+ inference_config_name,
+ training_config_name,
+ )
-def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
+def _get_model_info_from_inference_component_endpoint_without_inference_component_name(
endpoint_name: str, sagemaker_session: Session
-) -> Tuple[str, str, str]:
- """Given an endpoint name, derives the model ID, version, and inferred inference component name.
+) -> Tuple[str, str, Optional[str], Optional[str], str]:
+ """Given an endpoint name, derives the model ID, version, config names, and component name.
This function assumes the endpoint corresponds to an inference-component-based endpoint.
An endpoint is inference-component-based if and only if the associated endpoint config
@@ -98,14 +113,14 @@ def _get_model_id_version_from_inference_component_endpoint_without_inference_co
)
inference_component_name = inference_component_names[0]
return (
- *_get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
+ *_get_model_info_from_inference_component_endpoint_with_inference_component_name(
inference_component_name, sagemaker_session
),
inference_component_name,
)
-def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
+def _get_model_info_from_inference_component_endpoint_with_inference_component_name(
inference_component_name: str, sagemaker_session: Session
):
"""Returns the model ID and version inferred from a SageMaker inference component.
@@ -123,9 +138,12 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
f"inference-component/{inference_component_name}"
)
- model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
- inference_component_arn, sagemaker_session
- )
+ (
+ model_id,
+ model_version,
+ inference_config_name,
+ training_config_name,
+ ) = get_jumpstart_model_info_from_resource_arn(inference_component_arn, sagemaker_session)
if not model_id:
raise ValueError(
@@ -134,15 +152,15 @@ def _get_model_id_version_from_inference_component_endpoint_with_inference_compo
"when retrieving default predictor for this inference component."
)
- return model_id, model_version
+ return model_id, model_version, inference_config_name, training_config_name
-def _get_model_id_version_from_model_based_endpoint(
+def _get_model_info_from_model_based_endpoint(
endpoint_name: str,
inference_component_name: Optional[str],
sagemaker_session: Session,
-) -> Tuple[str, str]:
- """Returns the model ID and version inferred from a model-based endpoint.
+) -> Tuple[str, str, Optional[str], Optional[str]]:
+ """Returns the model ID, version and config name inferred from a model-based endpoint.
Raises:
ValueError: If an inference component name is supplied, or if the endpoint does
@@ -161,9 +179,12 @@ def _get_model_id_version_from_model_based_endpoint(
endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"
- model_id, model_version = get_jumpstart_model_id_version_from_resource_arn(
- endpoint_arn, sagemaker_session
- )
+ (
+ model_id,
+ model_version,
+ inference_config_name,
+ training_config_name,
+ ) = get_jumpstart_model_info_from_resource_arn(endpoint_arn, sagemaker_session)
if not model_id:
raise ValueError(
@@ -172,14 +193,14 @@ def _get_model_id_version_from_model_based_endpoint(
"predictor for this endpoint."
)
- return model_id, model_version
+ return model_id, model_version, inference_config_name, training_config_name
-def get_model_id_version_from_training_job(
+def get_model_info_from_training_job(
training_job_name: str,
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
-) -> Tuple[str, str]:
- """Returns the model ID and version inferred from a training job.
+) -> Tuple[str, str, Optional[str], Optional[str]]:
+ """Returns the model ID and version and config name inferred from a training job.
Raises:
ValueError: If the training job does not have tags from which the model ID
@@ -194,9 +215,12 @@ def get_model_id_version_from_training_job(
f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}"
)
- model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn(
- training_job_arn, sagemaker_session
- )
+ (
+ model_id,
+ inferred_model_version,
+ inference_config_name,
+ training_config_name,
+ ) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session)
model_version = inferred_model_version or None
@@ -207,4 +231,4 @@ def get_model_id_version_from_training_job(
"for this training job."
)
- return model_id, model_version
+ return model_id, model_version, inference_config_name, training_config_name
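Callers of these renamed helpers now unpack five elements (endpoints) or four (training jobs) instead of three or two. A sketch of the new calling convention, with placeholder resource names:

from sagemaker.jumpstart.session_utils import (
    get_model_info_from_endpoint,
    get_model_info_from_training_job,
)

(
    model_id,
    model_version,
    inference_component_name,
    inference_config_name,
    training_config_name,
) = get_model_info_from_endpoint(endpoint_name="my-endpoint")  # placeholder name

model_id, model_version, inference_cfg, training_cfg = get_model_info_from_training_job(
    training_job_name="my-training-job"  # placeholder name
)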
diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py
index 8e53bf6f83..f197421d65 100644
--- a/src/sagemaker/jumpstart/types.py
+++ b/src/sagemaker/jumpstart/types.py
@@ -744,12 +744,12 @@ def _get_regional_property(
class JumpStartBenchmarkStat(JumpStartDataHolderType):
- """Data class JumpStart benchmark stats."""
+ """Data class JumpStart benchmark stat."""
- __slots__ = ["name", "value", "unit"]
+ __slots__ = ["name", "value", "unit", "concurrency"]
def __init__(self, spec: Dict[str, Any]):
- """Initializes a JumpStartBenchmarkStat object
+ """Initializes a JumpStartBenchmarkStat object.
Args:
spec (Dict[str, Any]): Dictionary representation of benchmark stat.
@@ -765,6 +765,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.name: str = json_obj["name"]
self.value: str = json_obj["value"]
self.unit: Union[int, str] = json_obj["unit"]
+ self.concurrency: Union[int, str] = json_obj["concurrency"]
def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartBenchmarkStat object."""
@@ -858,7 +859,7 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType):
"model_subscription_link",
]
- def __init__(self, fields: Optional[Dict[str, Any]]):
+ def __init__(self, fields: Dict[str, Any]):
"""Initializes a JumpStartMetadataFields object.
Args:
@@ -877,7 +878,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.version: str = json_obj.get("version")
self.min_sdk_version: str = json_obj.get("min_sdk_version")
self.incremental_training_supported: bool = bool(
- json_obj.get("incremental_training_supported")
+ json_obj.get("incremental_training_supported", False)
)
self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = (
JumpStartECRSpecs(json_obj["hosting_ecr_specs"])
@@ -950,7 +951,10 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key")
- self.hosting_model_package_arns: Optional[Dict] = json_obj.get("hosting_model_package_arns")
+ model_package_arns = json_obj.get("hosting_model_package_arns")
+ self.hosting_model_package_arns: Optional[Dict] = (
+ model_package_arns if model_package_arns is not None else {}
+ )
self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True)
self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
@@ -1038,7 +1042,7 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields):
__slots__ = slots + JumpStartMetadataBaseFields.__slots__
- def __init__( # pylint: disable=super-init-not-called
+ def __init__(
self,
component_name: str,
component: Optional[Dict[str, Any]],
@@ -1049,7 +1053,10 @@ def __init__( # pylint: disable=super-init-not-called
component_name (str): Name of the component.
component (Dict[str, Any]):
Dictionary representation of the config component.
+ Raises:
+ ValueError: If the component field is invalid.
"""
+ super().__init__(component)
self.component_name = component_name
self.from_json(component)
@@ -1061,9 +1068,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
Dictionary representation of the config component.
"""
for field in json_obj.keys():
- if field not in self.__slots__:
- raise ValueError(f"Invalid component field: {field}")
- setattr(self, field, json_obj[field])
+ if field in self.__slots__:
+ setattr(self, field, json_obj[field])
class JumpStartMetadataConfig(JumpStartDataHolderType):
@@ -1072,30 +1078,57 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
__slots__ = [
"base_fields",
"benchmark_metrics",
+ "acceleration_configs",
"config_components",
"resolved_metadata_config",
+ "config_name",
+ "default_inference_config",
+ "default_incremental_training_config",
+ "supported_inference_configs",
+ "supported_incremental_training_configs",
]
def __init__(
self,
+ config_name: str,
+ config: Dict[str, Any],
base_fields: Dict[str, Any],
config_components: Dict[str, JumpStartConfigComponent],
- benchmark_metrics: Dict[str, JumpStartBenchmarkStat],
):
"""Initializes a JumpStartMetadataConfig object from its json representation.
Args:
+ config_name (str): Name of the config.
+ config (Dict[str, Any]):
+ Dictionary representation of the config.
base_fields (Dict[str, Any]):
- The default base fields that are used to construct the final resolved config.
+ The default base fields that are used to construct the resolved config.
config_components (Dict[str, JumpStartConfigComponent]):
The list of components that are used to construct the resolved config.
- benchmark_metrics (Dict[str, JumpStartBenchmarkStat]):
- The dictionary of benchmark metrics with name being the key.
"""
self.base_fields = base_fields
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
- self.benchmark_metrics: Dict[str, JumpStartBenchmarkStat] = benchmark_metrics
+ self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = (
+ {
+ stat_name: [JumpStartBenchmarkStat(stat) for stat in stats]
+ for stat_name, stats in config.get("benchmark_metrics").items()
+ }
+ if config and config.get("benchmark_metrics")
+ else None
+ )
+ self.acceleration_configs = config.get("acceleration_configs")
self.resolved_metadata_config: Optional[Dict[str, Any]] = None
+ self.config_name: Optional[str] = config_name
+ self.default_inference_config: Optional[str] = config.get("default_inference_config")
+ self.default_incremental_training_config: Optional[str] = config.get(
+ "default_incremental_training_config"
+ )
+ self.supported_inference_configs: Optional[List[str]] = config.get(
+ "supported_inference_configs"
+ )
+ self.supported_incremental_training_configs: Optional[List[str]] = config.get(
+ "supported_incremental_training_configs"
+ )
def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartMetadataConfig object."""
@@ -1104,7 +1137,7 @@ def to_json(self) -> Dict[str, Any]:
@property
def resolved_config(self) -> Dict[str, Any]:
- """Returns the final config that is resolved from the list of components.
+ """Returns the final config that is resolved from the components map.
Construct the final config by applying the list of configs from list index,
and apply to the base default fields in the current model specs.
@@ -1119,6 +1152,12 @@ def resolved_config(self) -> Dict[str, Any]:
deepcopy(component.to_json()),
component.OVERRIDING_DENY_LIST,
)
+
+ # Remove environment variables from resolved config if using model packages
+ hosting_model_package_arns = resolved_config.get("hosting_model_package_arns")
+ if hosting_model_package_arns is not None and hosting_model_package_arns != {}:
+ resolved_config["inference_environment_variables"] = []
+
self.resolved_metadata_config = resolved_config
return resolved_config
@@ -1139,7 +1178,7 @@ def __init__(
Args:
configs (Dict[str, JumpStartMetadataConfig]):
- List of configs that the current model has.
+ The map of JumpStartMetadataConfig objects, keyed by config name.
config_rankings (JumpStartConfigRanking):
Config ranking class represents the ranking of the configs in the model.
scope (JumpStartScriptScope):
@@ -1158,22 +1197,36 @@ def get_top_config_from_ranking(
self,
ranking_name: str = JumpStartConfigRankingName.DEFAULT,
instance_type: Optional[str] = None,
- ) -> JumpStartMetadataConfig:
- """Gets the best the config based on config ranking."""
- if self.configs and (
- not self.config_rankings or not self.config_rankings.get(ranking_name)
- ):
- raise ValueError("Config exists but missing config ranking.")
+ ) -> Optional[JumpStartMetadataConfig]:
+ """Gets the best the config based on config ranking.
+
+ Fallback to use the ordering in config names if
+ ranking is not available.
+ Args:
+ ranking_name (str):
+ The ranking name that config priority is based on.
+ instance_type (Optional[str]):
+ The instance type which the config selection is based on.
+
+ Raises:
+ NotImplementedError: If the scope is unrecognized.
+ """
if self.scope == JumpStartScriptScope.INFERENCE:
instance_type_attribute = "supported_inference_instance_types"
elif self.scope == JumpStartScriptScope.TRAINING:
instance_type_attribute = "supported_training_instance_types"
else:
- raise ValueError(f"Unknown script scope {self.scope}")
+ raise NotImplementedError(f"Unknown script scope {self.scope}")
- rankings = self.config_rankings.get(ranking_name)
- for config_name in rankings.rankings:
+ if self.configs and (
+ not self.config_rankings or not self.config_rankings.get(ranking_name)
+ ):
+ ranked_config_names = sorted(list(self.configs.keys()))
+ else:
+ rankings = self.config_rankings.get(ranking_name)
+ ranked_config_names = rankings.rankings
+ for config_name in ranked_config_names:
resolved_config = self.configs[config_name].resolved_config
if instance_type and instance_type not in getattr(
resolved_config, instance_type_attribute
@@ -1198,12 +1251,13 @@ class JumpStartModelSpecs(JumpStartMetadataBaseFields):
__slots__ = JumpStartMetadataBaseFields.__slots__ + slots
- def __init__(self, spec: Dict[str, Any]): # pylint: disable=super-init-not-called
+ def __init__(self, spec: Dict[str, Any]):
"""Initializes a JumpStartModelSpecs object from its json representation.
Args:
spec (Dict[str, Any]): Dictionary representation of spec.
"""
+ super().__init__(spec)
self.from_json(spec)
if self.inference_configs and self.inference_configs.get_top_config_from_ranking():
super().from_json(self.inference_configs.get_top_config_from_ranking().resolved_config)
@@ -1234,6 +1288,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
+ alias,
+ config,
json_obj,
(
{
@@ -1243,14 +1299,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if config and config.get("component_names")
else None
),
- (
- {
- stat_name: JumpStartBenchmarkStat(stat)
- for stat_name, stat in config.get("benchmark_metrics").items()
- }
- if config and config.get("benchmark_metrics")
- else None
- ),
)
for alias, config in json_obj["inference_configs"].items()
}
@@ -1286,6 +1334,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
+ alias,
+ config,
json_obj,
(
{
@@ -1295,14 +1345,6 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if config and config.get("component_names")
else None
),
- (
- {
- stat_name: JumpStartBenchmarkStat(stat)
- for stat_name, stat in config.get("benchmark_metrics").items()
- }
- if config and config.get("benchmark_metrics")
- else None
- ),
)
for alias, config in json_obj["training_configs"].items()
}
@@ -1330,13 +1372,26 @@ def set_config(
config_name (str): Name of the config.
scope (JumpStartScriptScope, optional):
Scope of the config. Defaults to JumpStartScriptScope.INFERENCE.
+
+ Raises:
+ ValueError: If the scope is not supported, or cannot find config name.
"""
if scope == JumpStartScriptScope.INFERENCE:
- super().from_json(self.inference_configs.configs[config_name].resolved_config)
+ metadata_configs = self.inference_configs
elif scope == JumpStartScriptScope.TRAINING and self.training_supported:
- super().from_json(self.training_configs.configs[config_name].resolved_config)
+ metadata_configs = self.training_configs
else:
- raise ValueError(f"Unknown Jumpstart Script scope {scope}.")
+ raise ValueError(f"Unknown Jumpstart script scope {scope}.")
+
+ config_object = metadata_configs.configs.get(config_name)
+ if not config_object:
+ error_msg = f"Cannot find Jumpstart config name {config_name}. "
+ config_names = list(metadata_configs.configs.keys())
+ if config_names:
+ error_msg += f"List of config names that is supported by the model: {config_names}"
+ raise ValueError(error_msg)
+
+ super().from_json(config_object.resolved_config)
def supports_prepacked_inference(self) -> bool:
"""Returns True if the model has a prepacked inference artifact."""
@@ -1437,11 +1492,11 @@ class JumpStartKwargs(JumpStartDataHolderType):
SERIALIZATION_EXCLUSION_SET: Set[str] = set()
- def to_kwargs_dict(self):
+ def to_kwargs_dict(self, exclude_keys: bool = True):
"""Serializes object to dictionary to be used for kwargs for method arguments."""
kwargs_dict = {}
for field in self.__slots__:
- if field not in self.SERIALIZATION_EXCLUSION_SET:
+ if not exclude_keys or field not in self.SERIALIZATION_EXCLUSION_SET:
att_value = getattr(self, field)
if att_value is not None:
kwargs_dict[field] = getattr(self, field)
@@ -1479,6 +1534,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"model_package_arn",
"training_instance_type",
"resources",
+ "config_name",
]
SERIALIZATION_EXCLUSION_SET = {
@@ -1491,6 +1547,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"region",
"model_package_arn",
"training_instance_type",
+ "config_name",
}
def __init__(
@@ -1522,6 +1579,7 @@ def __init__(
model_package_arn: Optional[str] = None,
training_instance_type: Optional[str] = None,
resources: Optional[ResourceRequirements] = None,
+ config_name: Optional[str] = None,
) -> None:
"""Instantiates JumpStartModelInitKwargs object."""
@@ -1552,6 +1610,7 @@ def __init__(
self.model_package_arn = model_package_arn
self.training_instance_type = training_instance_type
self.resources = resources
+ self.config_name = config_name
class JumpStartModelDeployKwargs(JumpStartKwargs):
@@ -1587,6 +1646,8 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"endpoint_logging",
"resources",
"endpoint_type",
+ "config_name",
+ "routing_config",
]
SERIALIZATION_EXCLUSION_SET = {
@@ -1598,6 +1659,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"tolerate_vulnerable_model",
"sagemaker_session",
"training_instance_type",
+ "config_name",
}
def __init__(
@@ -1631,6 +1693,8 @@ def __init__(
endpoint_logging: Optional[bool] = None,
resources: Optional[ResourceRequirements] = None,
endpoint_type: Optional[EndpointType] = None,
+ config_name: Optional[str] = None,
+ routing_config: Optional[Dict[str, Any]] = None,
) -> None:
"""Instantiates JumpStartModelDeployKwargs object."""
@@ -1663,6 +1727,8 @@ def __init__(
self.endpoint_logging = endpoint_logging
self.resources = resources
self.endpoint_type = endpoint_type
+ self.config_name = config_name
+ self.routing_config = routing_config
class JumpStartEstimatorInitKwargs(JumpStartKwargs):
@@ -1723,6 +1789,8 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"disable_output_compression",
"enable_infra_check",
"enable_remote_debug",
+ "config_name",
+ "enable_session_tag_chaining",
]
SERIALIZATION_EXCLUSION_SET = {
@@ -1732,6 +1800,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"model_id",
"model_version",
"model_type",
+ "config_name",
}
def __init__(
@@ -1790,6 +1859,8 @@ def __init__(
disable_output_compression: Optional[bool] = None,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
+ config_name: Optional[str] = None,
+ enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""
@@ -1849,6 +1920,8 @@ def __init__(
self.disable_output_compression = disable_output_compression
self.enable_infra_check = enable_infra_check
self.enable_remote_debug = enable_remote_debug
+ self.config_name = config_name
+ self.enable_session_tag_chaining = enable_session_tag_chaining
class JumpStartEstimatorFitKwargs(JumpStartKwargs):
@@ -1867,6 +1940,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
"tolerate_deprecated_model",
"tolerate_vulnerable_model",
"sagemaker_session",
+ "config_name",
]
SERIALIZATION_EXCLUSION_SET = {
@@ -1877,6 +1951,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
"tolerate_deprecated_model",
"tolerate_vulnerable_model",
"sagemaker_session",
+ "config_name",
}
def __init__(
@@ -1893,6 +1968,7 @@ def __init__(
tolerate_deprecated_model: Optional[bool] = None,
tolerate_vulnerable_model: Optional[bool] = None,
sagemaker_session: Optional[Session] = None,
+ config_name: Optional[str] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""
@@ -1908,6 +1984,7 @@ def __init__(
self.tolerate_deprecated_model = tolerate_deprecated_model
self.tolerate_vulnerable_model = tolerate_vulnerable_model
self.sagemaker_session = sagemaker_session
+ self.config_name = config_name
class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
@@ -1953,6 +2030,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
"tolerate_vulnerable_model",
"model_name",
"use_compiled_model",
+ "config_name",
]
SERIALIZATION_EXCLUSION_SET = {
@@ -1962,6 +2040,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
"model_id",
"model_version",
"sagemaker_session",
+ "config_name",
}
def __init__(
@@ -2005,6 +2084,7 @@ def __init__(
tolerate_deprecated_model: Optional[bool] = None,
tolerate_vulnerable_model: Optional[bool] = None,
use_compiled_model: bool = False,
+ config_name: Optional[str] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""
@@ -2047,6 +2127,7 @@ def __init__(
self.tolerate_deprecated_model = tolerate_deprecated_model
self.tolerate_vulnerable_model = tolerate_vulnerable_model
self.use_compiled_model = use_compiled_model
+ self.config_name = config_name
class JumpStartModelRegisterKwargs(JumpStartKwargs):
@@ -2081,6 +2162,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
"data_input_configuration",
"skip_model_validation",
"source_uri",
+ "config_name",
]
SERIALIZATION_EXCLUSION_SET = {
@@ -2090,6 +2172,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
"model_id",
"model_version",
"sagemaker_session",
+ "config_name",
}
def __init__(
@@ -2122,6 +2205,7 @@ def __init__(
data_input_configuration: Optional[str] = None,
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
+ config_name: Optional[str] = None,
) -> None:
"""Instantiates JumpStartModelRegisterKwargs object."""
@@ -2154,3 +2238,124 @@ def __init__(
self.data_input_configuration = data_input_configuration
self.skip_model_validation = skip_model_validation
self.source_uri = source_uri
+ self.config_name = config_name
+
+
+class BaseDeploymentConfigDataHolder(JumpStartDataHolderType):
+ """Base class for Deployment Config Data."""
+
+ def _convert_to_pascal_case(self, attr_name: str) -> str:
+ """Converts a snake_case attribute name into a camelCased string.
+
+ Args:
+ attr_name (str): The snake_case attribute name.
+ Returns:
+ str: The PascalCased attribute name.
+ """
+ return attr_name.replace("_", " ").title().replace(" ", "")
+
+ def to_json(self) -> Dict[str, Any]:
+ """Represents ``This`` object as JSON."""
+ json_obj = {}
+ for att in self.__slots__:
+ if hasattr(self, att):
+ cur_val = getattr(self, att)
+ att = self._convert_to_pascal_case(att)
+ json_obj[att] = self._val_to_json(cur_val)
+ return json_obj
+
+ def _val_to_json(self, val: Any) -> Any:
+ """Converts the given value to JSON.
+
+ Args:
+ val (Any): The value to convert.
+ Returns:
+ Any: The converted json value.
+ """
+ if issubclass(type(val), JumpStartDataHolderType):
+ if isinstance(val, JumpStartBenchmarkStat):
+ val.name = val.name.replace("_", " ").title()
+ return val.to_json()
+ if isinstance(val, list):
+ list_obj = []
+ for obj in val:
+ list_obj.append(self._val_to_json(obj))
+ return list_obj
+ if isinstance(val, dict):
+ dict_obj = {}
+ for k, v in val.items():
+ if isinstance(v, JumpStartDataHolderType):
+ dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v)
+ else:
+ dict_obj[k] = self._val_to_json(v)
+ return dict_obj
+ return val
+
+
+class DeploymentArgs(BaseDeploymentConfigDataHolder):
+ """Dataclass representing a Deployment Args."""
+
+ __slots__ = [
+ "image_uri",
+ "model_data",
+ "model_package_arn",
+ "environment",
+ "instance_type",
+ "compute_resource_requirements",
+ "model_data_download_timeout",
+ "container_startup_health_check_timeout",
+ ]
+
+ def __init__(
+ self,
+ init_kwargs: Optional[JumpStartModelInitKwargs] = None,
+ deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
+ resolved_config: Optional[Dict[str, Any]] = None,
+ ):
+ """Instantiates DeploymentArgs object."""
+ if init_kwargs is not None:
+ self.image_uri = init_kwargs.image_uri
+ self.model_data = init_kwargs.model_data
+ self.model_package_arn = init_kwargs.model_package_arn
+ self.instance_type = init_kwargs.instance_type
+ self.environment = init_kwargs.env
+ if init_kwargs.resources is not None:
+ self.compute_resource_requirements = (
+ init_kwargs.resources.get_compute_resource_requirements()
+ )
+ if deploy_kwargs is not None:
+ self.model_data_download_timeout = deploy_kwargs.model_data_download_timeout
+ self.container_startup_health_check_timeout = (
+ deploy_kwargs.container_startup_health_check_timeout
+ )
+ if resolved_config is not None:
+ self.default_instance_type = resolved_config.get("default_inference_instance_type")
+ self.supported_instance_types = resolved_config.get(
+ "supported_inference_instance_types"
+ )
+
+
+class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
+ """Dataclass representing a Deployment Config Metadata"""
+
+ __slots__ = [
+ "deployment_config_name",
+ "deployment_args",
+ "acceleration_configs",
+ "benchmark_metrics",
+ ]
+
+ def __init__(
+ self,
+ config_name: Optional[str] = None,
+ metadata_config: Optional[JumpStartMetadataConfig] = None,
+ init_kwargs: Optional[JumpStartModelInitKwargs] = None,
+ deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
+ ):
+ """Instantiates DeploymentConfigMetadata object."""
+ self.deployment_config_name = config_name
+ self.deployment_args = DeploymentArgs(
+ init_kwargs, deploy_kwargs, metadata_config.resolved_config
+ )
+ self.benchmark_metrics = metadata_config.benchmark_metrics
+ self.acceleration_configs = metadata_config.acceleration_configs
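to_json() on these data holders emits PascalCase keys via _convert_to_pascal_case. A standalone sketch of that conversion, mirroring the method above:

def to_pascal_case(attr_name: str) -> str:
    # "image_uri" -> "Image Uri" -> "ImageUri"
    return attr_name.replace("_", " ").title().replace(" ", "")

assert to_pascal_case("image_uri") == "ImageUri"
assert to_pascal_case("container_startup_health_check_timeout") == (
    "ContainerStartupHealthCheckTimeout"
)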
diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py
index 63cfac0939..a1b5d7fa9a 100644
--- a/src/sagemaker/jumpstart/utils.py
+++ b/src/sagemaker/jumpstart/utils.py
@@ -12,11 +12,14 @@
# language governing permissions and limitations under the License.
"""This module contains utilities related to SageMaker JumpStart."""
from __future__ import absolute_import
+
import logging
import os
+from functools import lru_cache, wraps
from typing import Any, Dict, List, Set, Optional, Tuple, Union
from urllib.parse import urlparse
import boto3
+from botocore.exceptions import ClientError
from packaging.version import Version
import sagemaker
from sagemaker.config.config_schema import (
@@ -41,10 +44,11 @@
JumpStartModelHeader,
JumpStartModelSpecs,
JumpStartVersionedModelId,
+ DeploymentConfigMetadata,
)
from sagemaker.session import Session
from sagemaker.config import load_sagemaker_config
-from sagemaker.utils import resolve_value_from_config, TagsDict
+from sagemaker.utils import resolve_value_from_config, TagsDict, get_instance_rate_per_hour
from sagemaker.workflow import is_pipeline_variable
@@ -123,7 +127,7 @@ def get_jumpstart_gated_content_bucket(
def get_jumpstart_content_bucket(
region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
) -> str:
- """Returns regionalized content bucket name for JumpStart.
+ """Returns the regionalized content bucket name for JumpStart.
Raises:
ValueError: If JumpStart is not launched in ``region``.
@@ -318,6 +322,8 @@ def add_single_jumpstart_tag(
tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags)
or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags)
or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags)
+ or tag_key_in_array(enums.JumpStartTag.INFERENCE_CONFIG_NAME, curr_tags)
+ or tag_key_in_array(enums.JumpStartTag.TRAINING_CONFIG_NAME, curr_tags)
)
if is_uri
else False
@@ -348,11 +354,13 @@ def get_jumpstart_base_name_if_jumpstart_model(
return None
-def add_jumpstart_model_id_version_tags(
+def add_jumpstart_model_info_tags(
tags: Optional[List[TagsDict]],
model_id: str,
model_version: str,
model_type: Optional[enums.JumpStartModelType] = None,
+ config_name: Optional[str] = None,
+ scope: Optional[enums.JumpStartScriptScope] = None,
) -> List[TagsDict]:
"""Add custom model ID and version tags to JumpStart related resources."""
if model_id is None or model_version is None:
@@ -376,6 +384,20 @@ def add_jumpstart_model_id_version_tags(
tags,
is_uri=False,
)
+ if config_name and scope == enums.JumpStartScriptScope.INFERENCE:
+ tags = add_single_jumpstart_tag(
+ config_name,
+ enums.JumpStartTag.INFERENCE_CONFIG_NAME,
+ tags,
+ is_uri=False,
+ )
+ if config_name and scope == enums.JumpStartScriptScope.TRAINING:
+ tags = add_single_jumpstart_tag(
+ config_name,
+ enums.JumpStartTag.TRAINING_CONFIG_NAME,
+ tags,
+ is_uri=False,
+ )
return tags
@@ -547,6 +569,7 @@ def verify_model_region_and_return_specs(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> JumpStartModelSpecs:
"""Verifies that an acceptable model_id, version, scope, and region combination is provided.
@@ -569,6 +592,7 @@ def verify_model_region_and_return_specs(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Raises:
NotImplementedError: If the scope is not supported.
@@ -634,6 +658,9 @@ def verify_model_region_and_return_specs(
scope=constants.JumpStartScriptScope.TRAINING,
)
+ if model_specs and config_name:
+ model_specs.set_config(config_name, scope)
+
return model_specs
@@ -795,52 +822,80 @@ def validate_model_id_and_get_type(
return None
-def get_jumpstart_model_id_version_from_resource_arn(
+def _extract_value_from_list_of_tags(
+ tag_keys: List[str],
+ list_tags_result: List[str],
+ resource_name: str,
+ resource_arn: str,
+) -> Optional[str]:
+ """Extracts value from list of tags with check of duplicate tags.
+
+ Returns None if no value is found.
+ """
+ resolved_value = None
+ for tag_key in tag_keys:
+ try:
+ value_from_tag = get_tag_value(tag_key, list_tags_result)
+ except KeyError:
+ continue
+ if value_from_tag is not None:
+ if resolved_value is not None and value_from_tag != resolved_value:
+ constants.JUMPSTART_LOGGER.warning(
+ "Found multiple %s tags on the following resource: %s",
+ resource_name,
+ resource_arn,
+ )
+ resolved_value = None
+ break
+ resolved_value = value_from_tag
+ return resolved_value
+
+
+def get_jumpstart_model_info_from_resource_arn(
resource_arn: str,
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
-) -> Tuple[Optional[str], Optional[str]]:
- """Returns the JumpStart model ID and version if in resource tags.
+) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
+ """Returns the JumpStart model ID, version and config name if in resource tags.
- Returns 'None' if model ID or version cannot be inferred from tags.
+ Returns 'None' for any of these values that cannot be inferred from tags.
"""
list_tags_result = sagemaker_session.list_tags(resource_arn)
- model_id: Optional[str] = None
- model_version: Optional[str] = None
-
model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]
+ inference_config_name_keys = [enums.JumpStartTag.INFERENCE_CONFIG_NAME]
+ training_config_name_keys = [enums.JumpStartTag.TRAINING_CONFIG_NAME]
+
+ model_id: Optional[str] = _extract_value_from_list_of_tags(
+ tag_keys=model_id_keys,
+ list_tags_result=list_tags_result,
+ resource_name="model ID",
+ resource_arn=resource_arn,
+ )
- for model_id_key in model_id_keys:
- try:
- model_id_from_tag = get_tag_value(model_id_key, list_tags_result)
- except KeyError:
- continue
- if model_id_from_tag is not None:
- if model_id is not None and model_id_from_tag != model_id:
- constants.JUMPSTART_LOGGER.warning(
- "Found multiple model ID tags on the following resource: %s", resource_arn
- )
- model_id = None
- break
- model_id = model_id_from_tag
+ model_version: Optional[str] = _extract_value_from_list_of_tags(
+ tag_keys=model_version_keys,
+ list_tags_result=list_tags_result,
+ resource_name="model version",
+ resource_arn=resource_arn,
+ )
- for model_version_key in model_version_keys:
- try:
- model_version_from_tag = get_tag_value(model_version_key, list_tags_result)
- except KeyError:
- continue
- if model_version_from_tag is not None:
- if model_version is not None and model_version_from_tag != model_version:
- constants.JUMPSTART_LOGGER.warning(
- "Found multiple model version tags on the following resource: %s", resource_arn
- )
- model_version = None
- break
- model_version = model_version_from_tag
+ inference_config_name: Optional[str] = _extract_value_from_list_of_tags(
+ tag_keys=inference_config_name_keys,
+ list_tags_result=list_tags_result,
+ resource_name="inference config name",
+ resource_arn=resource_arn,
+ )
+
+ training_config_name: Optional[str] = _extract_value_from_list_of_tags(
+ tag_keys=training_config_name_keys,
+ list_tags_result=list_tags_result,
+ resource_name="training config name",
+ resource_arn=resource_arn,
+ )
- return model_id, model_version
+ return model_id, model_version, inference_config_name, training_config_name
def get_region_fallback(
@@ -890,7 +945,11 @@ def get_config_names(
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
) -> List[str]:
- """Returns a list of config names for the given model ID and region."""
+ """Returns a list of config names for the given model ID and region.
+
+ Raises:
+ ValueError: If the script scope is not supported by JumpStart.
+ """
model_specs = verify_model_region_and_return_specs(
region=region,
model_id=model_id,
@@ -905,7 +964,7 @@ def get_config_names(
elif scope == enums.JumpStartScriptScope.TRAINING:
metadata_configs = model_specs.training_configs
else:
- raise ValueError(f"Unknown script scope {scope}.")
+ raise ValueError(f"Unknown script scope: {scope}.")
return list(metadata_configs.configs.keys()) if metadata_configs else []
@@ -919,7 +978,11 @@ def get_benchmark_stats(
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
) -> Dict[str, List[JumpStartBenchmarkStat]]:
- """Returns benchmark stats for the given model ID and region."""
+ """Returns benchmark stats for the given model ID and region.
+
+ Raises:
+ ValueError: If the script scope is not supported by JumpStart.
+ """
model_specs = verify_model_region_and_return_specs(
region=region,
model_id=model_id,
@@ -934,7 +997,7 @@ def get_benchmark_stats(
elif scope == enums.JumpStartScriptScope.TRAINING:
metadata_configs = model_specs.training_configs
else:
- raise ValueError(f"Unknown script scope {scope}.")
+ raise ValueError(f"Unknown script scope: {scope}.")
if not config_names:
config_names = metadata_configs.configs.keys() if metadata_configs else []
@@ -942,7 +1005,7 @@ def get_benchmark_stats(
benchmark_stats = {}
for config_name in config_names:
if config_name not in metadata_configs.configs:
- raise ValueError(f"Unknown config name: '{config_name}'")
+ raise ValueError(f"Unknown config name: {config_name}")
benchmark_stats[config_name] = metadata_configs.configs.get(config_name).benchmark_metrics
return benchmark_stats
@@ -956,8 +1019,12 @@ def get_jumpstart_configs(
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
-) -> Dict[str, List[JumpStartMetadataConfig]]:
- """Returns metadata configs for the given model ID and region."""
+) -> Dict[str, JumpStartMetadataConfig]:
+ """Returns metadata configs for the given model ID and region.
+
+ Raises:
+ ValueError: If the script scope is not supported by JumpStart.
+ """
model_specs = verify_model_region_and_return_specs(
region=region,
model_id=model_id,
@@ -972,13 +1039,262 @@ def get_jumpstart_configs(
elif scope == enums.JumpStartScriptScope.TRAINING:
metadata_configs = model_specs.training_configs
else:
- raise ValueError(f"Unknown script scope {scope}.")
+ raise ValueError(f"Unknown script scope: {scope}.")
if not config_names:
- config_names = metadata_configs.configs.keys() if metadata_configs else []
+ config_names = (
+ metadata_configs.config_rankings.get("overall").rankings if metadata_configs else []
+ )
return (
{config_name: metadata_configs.configs[config_name] for config_name in config_names}
if metadata_configs
else {}
)
+
+
+def add_instance_rate_stats_to_benchmark_metrics(
+ region: str,
+ benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]],
+) -> Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]:
+ """Adds instance types metric stats to the given benchmark_metrics dict.
+
+ Args:
+ region (str): AWS region.
+ benchmark_metrics (Optional[Dict[str, List[JumpStartBenchmarkStat]]]):
+ Returns:
+ Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]:
+ Contains Error and metrics.
+ """
+ if not benchmark_metrics:
+ return None
+
+ err_message = None
+ final_benchmark_metrics = {}
+ for instance_type, benchmark_metric_stats in benchmark_metrics.items():
+ instance_type = instance_type if instance_type.startswith("ml.") else f"ml.{instance_type}"
+
+ if not has_instance_rate_stat(benchmark_metric_stats) and not err_message:
+ try:
+ instance_type_rate = get_instance_rate_per_hour(
+ instance_type=instance_type, region=region
+ )
+
+ if not benchmark_metric_stats:
+ benchmark_metric_stats = []
+ benchmark_metric_stats.append(
+ JumpStartBenchmarkStat({"concurrency": None, **instance_type_rate})
+ )
+
+ final_benchmark_metrics[instance_type] = benchmark_metric_stats
+ except ClientError as e:
+ final_benchmark_metrics[instance_type] = benchmark_metric_stats
+ err_message = e.response["Error"]
+ except Exception: # pylint: disable=W0703
+ final_benchmark_metrics[instance_type] = benchmark_metric_stats
+ else:
+ final_benchmark_metrics[instance_type] = benchmark_metric_stats
+
+ return err_message, final_benchmark_metrics
+
+
+def has_instance_rate_stat(benchmark_metric_stats: Optional[List[JumpStartBenchmarkStat]]) -> bool:
+ """Determines whether a benchmark metric stats contains instance rate metric stat.
+
+ Args:
+ benchmark_metric_stats (Optional[List[JumpStartBenchmarkStat]]):
+ List of benchmark metric stats.
+ Returns:
+ bool: Whether the benchmark metric stats contains instance rate metric stat.
+ """
+ if benchmark_metric_stats is None:
+ return True
+ for benchmark_metric_stat in benchmark_metric_stats:
+ if benchmark_metric_stat.name.lower() == "instance rate":
+ return True
+ return False
+
+
+def get_metrics_from_deployment_configs(
+ deployment_configs: Optional[List[DeploymentConfigMetadata]],
+) -> Dict[str, List[str]]:
+ """Extracts benchmark metrics from deployment configs metadata.
+
+ Args:
+ deployment_configs (Optional[List[DeploymentConfigMetadata]]):
+ List of deployment configs metadata.
+ Returns:
+ Dict[str, List[str]]: Deployment configs bench metrics dict.
+ """
+ if not deployment_configs:
+ return {}
+
+ data = {"Instance Type": [], "Config Name": [], "Concurrent Users": []}
+ instance_rate_data = {}
+ for index, deployment_config in enumerate(deployment_configs):
+ benchmark_metrics = deployment_config.benchmark_metrics
+ if not deployment_config.deployment_args or not benchmark_metrics:
+ continue
+
+ for current_instance_type, current_instance_type_metrics in benchmark_metrics.items():
+ instance_type_rate, concurrent_users = _normalize_benchmark_metrics(
+ current_instance_type_metrics
+ )
+
+ for concurrent_user, metrics in concurrent_users.items():
+ instance_type_to_display = (
+ f"{current_instance_type} (Default)"
+ if index == 0
+ and int(concurrent_user) == 1
+ and current_instance_type
+ == deployment_config.deployment_args.default_instance_type
+ else current_instance_type
+ )
+
+ data["Config Name"].append(deployment_config.deployment_config_name)
+ data["Instance Type"].append(instance_type_to_display)
+ data["Concurrent Users"].append(concurrent_user)
+
+ if instance_type_rate:
+ instance_rate_column_name = (
+ f"{instance_type_rate.name} ({instance_type_rate.unit})"
+ )
+ instance_rate_data[instance_rate_column_name] = instance_rate_data.get(
+ instance_rate_column_name, []
+ )
+ instance_rate_data[instance_rate_column_name].append(instance_type_rate.value)
+
+ for metric in metrics:
+ column_name = _normalize_benchmark_metric_column_name(metric.name)
+ data[column_name] = data.get(column_name, [])
+ data[column_name].append(metric.value)
+
+ data = {**data, **instance_rate_data}
+ return data
+
+
+def _normalize_benchmark_metric_column_name(name: str) -> str:
+ """Normalizes benchmark metric column name.
+
+ Args:
+ name (str): Name of the metric.
+ Returns:
+ str: Normalized metric column name.
+ """
+ if "latency" in name.lower():
+ name = "Latency for each user (TTFT in ms)"
+ elif "throughput" in name.lower():
+ name = "Throughput per user (token/seconds)"
+ return name
+
+
+def _normalize_benchmark_metrics(
+ benchmark_metric_stats: List[JumpStartBenchmarkStat],
+) -> Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]:
+ """Normalizes benchmark metrics dict.
+
+ Args:
+ benchmark_metric_stats (List[JumpStartBenchmarkStat]):
+ List of benchmark metrics stats.
+ Returns:
+ Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]:
+ Normalized benchmark metrics dict.
+ """
+ instance_type_rate = None
+ concurrent_users = {}
+ for current_instance_type_metric in benchmark_metric_stats:
+ if current_instance_type_metric.name.lower() == "instance rate":
+ instance_type_rate = current_instance_type_metric
+ elif current_instance_type_metric.concurrency not in concurrent_users:
+ concurrent_users[current_instance_type_metric.concurrency] = [
+ current_instance_type_metric
+ ]
+ else:
+ concurrent_users[current_instance_type_metric.concurrency].append(
+ current_instance_type_metric
+ )
+
+ return instance_type_rate, concurrent_users
+
+
+def deployment_config_response_data(
+ deployment_configs: Optional[List[DeploymentConfigMetadata]],
+) -> List[Dict[str, Any]]:
+ """Deployment config api response data.
+
+ Args:
+ deployment_configs (Optional[List[DeploymentConfigMetadata]]):
+ List of deployment configs metadata.
+ Returns:
+ List[Dict[str, Any]]: List of deployment config api response data.
+ """
+ configs = []
+ if not deployment_configs:
+ return configs
+
+ for deployment_config in deployment_configs:
+ deployment_config_json = deployment_config.to_json()
+ benchmark_metrics = deployment_config_json.get("BenchmarkMetrics")
+ if benchmark_metrics and deployment_config.deployment_args:
+ deployment_config_json["BenchmarkMetrics"] = {
+ deployment_config.deployment_args.instance_type: benchmark_metrics.get(
+ deployment_config.deployment_args.instance_type
+ )
+ }
+
+ configs.append(deployment_config_json)
+ return configs
+
+
+def _deployment_config_lru_cache(_func=None, *, maxsize: int = 128, typed: bool = False):
+ """LRU cache for deployment configs."""
+
+ def has_instance_rate_metric(config: DeploymentConfigMetadata) -> bool:
+ """Determines whether metadata config contains instance rate metric stat.
+
+ Args:
+ config (DeploymentConfigMetadata): Metadata config metadata.
+ Returns:
+ bool: Whether the metadata config contains instance rate metric stat.
+ """
+ if config.benchmark_metrics is None:
+ return True
+ for benchmark_metric_stats in config.benchmark_metrics.values():
+ if not has_instance_rate_stat(benchmark_metric_stats):
+ return False
+ return True
+
+ def wrapper_cache(f):
+ f = lru_cache(maxsize=maxsize, typed=typed)(f)
+
+ @wraps(f)
+ def wrapped_f(*args, **kwargs):
+ res = f(*args, **kwargs)
+
+ # Clear the cache on the first call if the output does not
+ # contain instance rate metrics, as that is caused by a missing
+ # pricing policy and should be retried on a later call.
+ if f.cache_info().hits == 0 and f.cache_info().misses == 1:
+ if isinstance(res, list):
+ for item in res:
+ if isinstance(
+ item, DeploymentConfigMetadata
+ ) and not has_instance_rate_metric(item):
+ f.cache_clear()
+ break
+ elif isinstance(res, dict):
+ keys = list(res.keys())
+ if len(keys) == 0 or "Instance Rate" not in keys[-1]:
+ f.cache_clear()
+ elif len(res[keys[1]]) > len(res[keys[-1]]):
+ del res[keys[-1]]
+ f.cache_clear()
+ return res
+
+ wrapped_f.cache_info = f.cache_info
+ wrapped_f.cache_clear = f.cache_clear
+ return wrapped_f
+
+ if _func is None:
+ return wrapper_cache
+ return wrapper_cache(_func)
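The decorator wraps functools.lru_cache but refuses to keep a first result that lacks instance rate data, since that usually means the caller's role is missing pricing permissions and a retry should recompute. A toy illustration of the dict branch; the decorated function body is invented:

from sagemaker.jumpstart.utils import _deployment_config_lru_cache

@_deployment_config_lru_cache
def benchmarks(region: str):
    # Invented stand-in for a pricing-backed lookup. The last key lacks
    # "Instance Rate", so the decorator clears its cache after this call.
    return {"Instance Type": ["ml.g5.2xlarge"], "Config Name": ["tgi"]}

benchmarks("us-west-2")  # computed, then evicted (no instance rate column)
benchmarks("us-west-2")  # recomputed instead of served from a stale cache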
diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py
index c7098a1185..bcb0365f7b 100644
--- a/src/sagemaker/jumpstart/validators.py
+++ b/src/sagemaker/jumpstart/validators.py
@@ -171,6 +171,7 @@ def validate_hyperparameters(
sagemaker_session: Optional[session.Session] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
+ config_name: Optional[str] = None,
) -> None:
"""Validate hyperparameters for JumpStart models.
@@ -193,6 +194,7 @@ def validate_hyperparameters(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Raises:
JumpStartHyperparametersError: If the hyperparameters are not formatted correctly,
@@ -218,6 +220,7 @@ def validate_hyperparameters(
sagemaker_session=sagemaker_session,
tolerate_deprecated_model=tolerate_deprecated_model,
tolerate_vulnerable_model=tolerate_vulnerable_model,
+ config_name=config_name,
)
hyperparameters_specs = model_specs.hyperparameters
diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py
index 182f117913..3e2003674b 100644
--- a/src/sagemaker/lineage/query.py
+++ b/src/sagemaker/lineage/query.py
@@ -335,8 +335,8 @@ def _get_legend_line(self, component_name):
def _add_legend(self, path):
"""Embed legend to html file generated by pyvis."""
- f = open(path, "r")
- content = self.BeautifulSoup(f, "html.parser")
+ with open(path, "r") as f:
+ content = self.BeautifulSoup(f, "html.parser")
legend = """
str:
"""Retrieves the model artifact Amazon S3 URI for the model matching the given arguments.
@@ -57,6 +58,8 @@ def retrieve(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
+
Returns:
str: The model artifact S3 URI for the corresponding model.
@@ -81,4 +84,5 @@ def retrieve(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py
index 6f846bba65..780a1a56c8 100644
--- a/src/sagemaker/predictor.py
+++ b/src/sagemaker/predictor.py
@@ -18,7 +18,7 @@
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.factory.model import get_default_predictor
-from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint
+from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint
from sagemaker.session import Session
@@ -43,6 +43,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+ config_name: Optional[str] = None,
) -> Predictor:
"""Retrieves the default predictor for the model matching the given arguments.
@@ -65,6 +66,8 @@ def retrieve_default(
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
(exception not raised). False if these models should raise an exception.
(Default: False).
+ config_name (Optional[str]): The name of the configuration to use for the
+ predictor. (Default: None)
Returns:
Predictor: The default predictor to use for the model.
@@ -78,9 +81,9 @@ def retrieve_default(
inferred_model_id,
inferred_model_version,
inferred_inference_component_name,
- ) = get_model_id_version_from_endpoint(
- endpoint_name, inference_component_name, sagemaker_session
- )
+ inferred_config_name,
+ _,
+ ) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session)
if not inferred_model_id:
raise ValueError(
@@ -92,6 +95,7 @@ def retrieve_default(
model_id = inferred_model_id
model_version = model_version or inferred_model_version or "*"
inference_component_name = inference_component_name or inferred_inference_component_name
+ config_name = config_name or inferred_config_name or None
else:
model_version = model_version or "*"
@@ -110,4 +114,5 @@ def retrieve_default(
tolerate_vulnerable_model=tolerate_vulnerable_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index a4e24d1ff0..412926279c 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -276,6 +276,20 @@ def __init__(
kwargs["entry_point"] = entry_point
if distribution is not None:
+ # rewrite pytorchddp to smdistributed
+ if "pytorchddp" in distribution:
+ if "smdistributed" in distribution:
+ raise ValueError(
+ "Cannot use both pytorchddp and smdistributed "
+ "distribution options together.",
+ distribution,
+ )
+
+ # convert pytorchddp distribution into smdistributed distribution
+ distribution = distribution.copy()
+ distribution["smdistributed"] = {"dataparallel": distribution["pytorchddp"]}
+ del distribution["pytorchddp"]
+
distribution = validate_distribution(
distribution,
self.instance_groups,
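The rewrite above is a plain dict transformation performed before distribution validation. Its effect, shown on a standalone dict:

distribution = {"pytorchddp": {"enabled": True}}

# What PyTorch.__init__ now does when it sees the legacy key:
distribution = distribution.copy()
distribution["smdistributed"] = {"dataparallel": distribution["pytorchddp"]}
del distribution["pytorchddp"]

assert distribution == {"smdistributed": {"dataparallel": {"enabled": True}}}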
diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py
index df14ac558f..7808d0172a 100644
--- a/src/sagemaker/resource_requirements.py
+++ b/src/sagemaker/resource_requirements.py
@@ -37,6 +37,7 @@ def retrieve_default(
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
+ config_name: Optional[str] = None,
) -> ResourceRequirements:
"""Retrieves the default resource requirements for the model matching the given arguments.
@@ -62,6 +63,7 @@ def retrieve_default(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
host requirements specific for the instance type.
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default resource requirements to use for the model.
@@ -87,4 +89,5 @@ def retrieve_default(
model_type=model_type,
sagemaker_session=sagemaker_session,
instance_type=instance_type,
+ config_name=config_name,
)
diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py
index 9a1c4933d2..6e10785498 100644
--- a/src/sagemaker/script_uris.py
+++ b/src/sagemaker/script_uris.py
@@ -33,6 +33,7 @@ def retrieve(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> str:
"""Retrieves the script S3 URI associated with the model matching the given arguments.
@@ -55,6 +56,7 @@ def retrieve(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The model script URI for the corresponding model.
@@ -78,4 +80,5 @@ def retrieve(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py
index aefb52bd97..d197df731c 100644
--- a/src/sagemaker/serializers.py
+++ b/src/sagemaker/serializers.py
@@ -45,6 +45,7 @@ def retrieve_options(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> List[BaseSerializer]:
"""Retrieves the supported serializers for the model matching the given arguments.
@@ -66,6 +67,7 @@ def retrieve_options(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
List[SimpleBaseSerializer]: The supported serializers to use for the model.
@@ -85,6 +87,7 @@ def retrieve_options(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
+ config_name=config_name,
)
@@ -96,6 +99,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+ config_name: Optional[str] = None,
) -> BaseSerializer:
"""Retrieves the default serializer for the model matching the given arguments.
@@ -117,6 +121,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
SimpleBaseSerializer: The default serializer to use for the model.
@@ -137,4 +142,5 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
+ config_name=config_name,
)
diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py
index e3368869fe..e8ef546f7a 100644
--- a/src/sagemaker/serve/builder/jumpstart_builder.py
+++ b/src/sagemaker/serve/builder/jumpstart_builder.py
@@ -16,13 +16,14 @@
import copy
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
-from typing import Type
+from typing import Type, Any, List, Dict, Optional
import logging
from sagemaker.model import Model
from sagemaker import model_uris
from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources
from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees
+from sagemaker.serve.model_server.multi_model_server.prepare import prepare_mms_js_resources
from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.exceptions import (
@@ -35,6 +36,7 @@
from sagemaker.serve.utils.predictors import (
DjlLocalModePredictor,
TgiLocalModePredictor,
+ TransformersLocalModePredictor,
)
from sagemaker.serve.utils.local_hardware import (
_get_nb_instance,
@@ -90,6 +92,7 @@ def __init__(self):
self.existing_properties = None
self.prepared_for_tgi = None
self.prepared_for_djl = None
+ self.prepared_for_mms = None
self.schema_builder = None
self.nb_instance_type = None
self.ram_usage_model_load = None
@@ -137,7 +140,11 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
- if not hasattr(self, "prepared_for_djl") or not hasattr(self, "prepared_for_tgi"):
+ if (
+ not hasattr(self, "prepared_for_djl")
+ or not hasattr(self, "prepared_for_tgi")
+ or not hasattr(self, "prepared_for_mms")
+ ):
self.pysdk_model.model_data, env = self._prepare_for_mode()
elif overwrite_mode == Mode.LOCAL_CONTAINER:
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
@@ -160,6 +167,13 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
dependencies=self.dependencies,
model_data=self.pysdk_model.model_data,
)
+ elif not hasattr(self, "prepared_for_mms"):
+ self.js_model_config, self.prepared_for_mms = prepare_mms_js_resources(
+ model_path=self.model_path,
+ js_id=self.model,
+ dependencies=self.dependencies,
+ model_data=self.pysdk_model.model_data,
+ )
self._prepare_for_mode()
env = {}
@@ -179,6 +193,10 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
predictor = TgiLocalModePredictor(
self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
)
+ elif self.model_server == ModelServer.MMS:
+ predictor = TransformersLocalModePredictor(
+ self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
+ )
ram_usage_before = _get_ram_usage_mb()
self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
@@ -254,6 +272,24 @@ def _build_for_tgi_jumpstart(self):
self.pysdk_model.env.update(env)
+ def _build_for_mms_jumpstart(self):
+ """Placeholder docstring"""
+
+ env = {}
+ if self.mode == Mode.LOCAL_CONTAINER:
+ if not hasattr(self, "prepared_for_mms"):
+ self.js_model_config, self.prepared_for_mms = prepare_mms_js_resources(
+ model_path=self.model_path,
+ js_id=self.model,
+ dependencies=self.dependencies,
+ model_data=self.pysdk_model.model_data,
+ )
+ self._prepare_for_mode()
+ elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_mms"):
+ self.pysdk_model.model_data, env = self._prepare_for_mode()
+
+ self.pysdk_model.env.update(env)
+
def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800):
"""Tune for Jumpstart Models in Local Mode.
@@ -264,7 +300,7 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800)
returns:
Tuned Model.
"""
- if self.mode != Mode.LOCAL_CONTAINER:
+ if self.mode == Mode.SAGEMAKER_ENDPOINT:
logger.warning(
"Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
)
@@ -431,14 +467,61 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
)
+ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
+ """Sets the deployment config to apply to the model.
+
+ Args:
+ config_name (str):
+ The name of the deployment config to apply to the model.
+ Call list_deployment_configs to see the list of config names.
+ instance_type (str):
+ The instance_type that the model will use after setting
+ the config.
+ """
+ if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
+ raise Exception("Cannot set deployment config to an uninitialized model.")
+
+ self.pysdk_model.set_deployment_config(config_name, instance_type)
+
+ def get_deployment_config(self) -> Optional[Dict[str, Any]]:
+ """Gets the deployment config to apply to the model.
+
+ Returns:
+ Optional[Dict[str, Any]]: Deployment config to apply to this model.
+ """
+ if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
+ self._build_for_jumpstart()
+
+ return self.pysdk_model.deployment_config
+
+ def display_benchmark_metrics(self, **kwargs):
+ """Display Markdown Benchmark Metrics for deployment configs."""
+ if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
+ self._build_for_jumpstart()
+
+ self.pysdk_model.display_benchmark_metrics(**kwargs)
+
+ def list_deployment_configs(self) -> List[Dict[str, Any]]:
+ """List deployment configs for ``This`` model in the current region.
+
+ Returns:
+ List[Dict[str, Any]]: A list of deployment configs.
+ """
+ if not hasattr(self, "pysdk_model") or self.pysdk_model is None:
+ self._build_for_jumpstart()
+
+ return self.pysdk_model.list_deployment_configs()
+
def _build_for_jumpstart(self):
"""Placeholder docstring"""
+ if hasattr(self, "pysdk_model") and self.pysdk_model is not None:
+ return self.pysdk_model
+
# we do not pickle for jumpstart. set to none
self.secret_key = None
self.jumpstart = True
pysdk_model = self._create_pre_trained_js_model()
-
image_uri = pysdk_model.image_uri
logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
@@ -451,7 +534,6 @@ def _build_for_jumpstart(self):
if "djl-inference" in image_uri:
logger.info("Building for DJL JumpStart Model ID...")
self.model_server = ModelServer.DJL_SERVING
-
self.pysdk_model = pysdk_model
self.image_uri = self.pysdk_model.image_uri
@@ -461,16 +543,23 @@ def _build_for_jumpstart(self):
elif "tgi-inference" in image_uri:
logger.info("Building for TGI JumpStart Model ID...")
self.model_server = ModelServer.TGI
-
self.pysdk_model = pysdk_model
self.image_uri = self.pysdk_model.image_uri
self._build_for_tgi_jumpstart()
self.pysdk_model.tune = self.tune_for_tgi_jumpstart
- else:
+ elif "huggingface-pytorch-inference:" in image_uri:
+ logger.info("Building for MMS JumpStart Model ID...")
+ self.model_server = ModelServer.MMS
+ self.pysdk_model = pysdk_model
+ self.image_uri = self.pysdk_model.image_uri
+
+ self._build_for_mms_jumpstart()
+ elif self.mode != Mode.SAGEMAKER_ENDPOINT:
raise ValueError(
- "JumpStart Model ID was not packaged with djl-inference or tgi-inference container."
+ "JumpStart Model ID was not packaged "
+ "with djl-inference, tgi-inference, or mms-inference container."
)
return self.pysdk_model
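A hedged sketch of the new deployment-config surface added to the JumpStart builder above; the model ID, config name, and instance type are illustrative assumptions:

```python
from sagemaker.serve.builder.model_builder import ModelBuilder

builder = ModelBuilder(model="huggingface-llm-example-7b")  # hypothetical JumpStart ID
configs = builder.list_deployment_configs()  # builds the model lazily if needed
builder.display_benchmark_metrics()          # renders Markdown benchmark metrics
builder.set_deployment_config(
    config_name="example-config",            # hypothetical; pick one from configs
    instance_type="ml.g5.2xlarge",
)
```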
diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py
index 06b3d70aeb..44bc46b00b 100644
--- a/src/sagemaker/serve/builder/model_builder.py
+++ b/src/sagemaker/serve/builder/model_builder.py
@@ -29,12 +29,14 @@
from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer
from sagemaker.serve.builder.schema_builder import SchemaBuilder
+from sagemaker.serve.builder.tf_serving_builder import TensorflowServing
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
from sagemaker.serve.builder.serve_settings import _ServeSettings
from sagemaker.serve.builder.djl_builder import DJL
+from sagemaker.serve.builder.tei_builder import TEI
from sagemaker.serve.builder.tgi_builder import TGI
from sagemaker.serve.builder.jumpstart_builder import JumpStart
from sagemaker.serve.builder.transformers_builder import Transformers
@@ -59,6 +61,7 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.utils import task
from sagemaker.serve.utils.exceptions import TaskNotFoundException
+from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
from sagemaker.serve.utils.hardware_detector import (
_get_gpu_info,
@@ -89,12 +92,13 @@
ModelServer.TORCHSERVE,
ModelServer.TRITON,
ModelServer.DJL_SERVING,
+ ModelServer.TENSORFLOW_SERVING,
}
-# pylint: disable=attribute-defined-outside-init, disable=E1101
+# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901, disable=R1705
@dataclass
-class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
+class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, TEI):
"""Class that builds a deployable model.
Args:
@@ -165,7 +169,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers):
in order for model builder to build the artifacts correctly (according
to the model server). Possible values for this argument are
``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``,
- ``TRITON``, and``TGI``.
+ ``TRITON``, ``TGI``, and ``TEI``.
model_metadata (Optional[Dict[str, Any]): Dictionary used to override model metadata.
Currently, ``HF_TASK`` is overridable for HuggingFace model. HF_TASK should be set for
new models without task metadata in the Hub, adding unsupported task types will throw
@@ -493,6 +497,12 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
self.pysdk_model.model_package_arn = new_model_package.model_package_arn
new_model_package.deploy = self._model_builder_deploy_model_package_wrapper
self.model_package = new_model_package
+ if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
+ _maintain_lineage_tracking_for_mlflow_model(
+ mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
+ s3_upload_path=self.s3_upload_path,
+ sagemaker_session=self.sagemaker_session,
+ )
return new_model_package
def _model_builder_deploy_model_package_wrapper(self, *args, **kwargs):
@@ -551,12 +561,19 @@ def _model_builder_deploy_wrapper(
if "endpoint_logging" not in kwargs:
kwargs["endpoint_logging"] = True
- return self._original_deploy(
+ predictor = self._original_deploy(
*args,
instance_type=instance_type,
initial_instance_count=initial_instance_count,
**kwargs,
)
+ if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT:
+ _maintain_lineage_tracking_for_mlflow_model(
+ mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
+ s3_upload_path=self.s3_upload_path,
+ sagemaker_session=self.sagemaker_session,
+ )
+ return predictor
def _overwrite_mode_in_deploy(self, overwrite_mode: str):
"""Mode overwritten by customer during model.deploy()"""
@@ -653,6 +670,9 @@ def _initialize_for_mlflow(self) -> None:
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
if not _mlflow_input_is_local_path(mlflow_path):
# TODO: extend to package arn, run id and etc.
+ logger.info(
+ "Start downloading model artifacts from %s to %s", mlflow_path, self.model_path
+ )
_download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session)
else:
_copy_directory_contents(mlflow_path, self.model_path)
@@ -708,8 +728,6 @@ def build( # pylint: disable=R0911
self.role_arn = role_arn
self.sagemaker_session = sagemaker_session or Session()
- self.sagemaker_session.settings._local_download_dir = self.model_path
-
# https://github.com/boto/botocore/blob/develop/botocore/useragent.py#L258
# decorate to_string() due to
# https://github.com/boto/botocore/blob/develop/botocore/client.py#L1014-L1015
@@ -728,7 +746,7 @@ def build( # pylint: disable=R0911
" for production at this moment."
)
self._initialize_for_mlflow()
- _validate_input_for_mlflow(self.model_server)
+ _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR"))
if isinstance(self.model, str):
model_task = None
@@ -736,7 +754,7 @@ def build( # pylint: disable=R0911
model_task = self.model_metadata.get("HF_TASK")
if self._is_jumpstart_model_id():
return self._build_for_jumpstart()
- if self._is_djl(): # pylint: disable=R1705
+ if self._is_djl():
return self._build_for_djl()
else:
hf_model_md = get_huggingface_model_metadata(
@@ -747,8 +765,10 @@ def build( # pylint: disable=R0911
model_task = hf_model_md.get("pipeline_tag")
if self.schema_builder is None and model_task is not None:
self._hf_schema_builder_init(model_task)
- if model_task == "text-generation": # pylint: disable=R1705
+ if model_task == "text-generation":
return self._build_for_tgi()
+ if model_task == "sentence-similarity":
+ return self._build_for_tei()
elif self._can_fit_on_single_gpu():
return self._build_for_transformers()
elif (
@@ -767,6 +787,9 @@ def build( # pylint: disable=R0911
if self.model_server == ModelServer.TRITON:
return self._build_for_triton()
+ if self.model_server == ModelServer.TENSORFLOW_SERVING:
+ return self._build_for_tensorflow_serving()
+
raise ValueError("%s model server is not supported" % self.model_server)
def save(
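With the routing added above, a sentence-similarity Hugging Face model now lands on the TEI builder without an explicit `model_server`. A sketch under the assumption that the model ID resolves to the sentence-similarity task:

```python
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder

builder = ModelBuilder(
    model="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical HF model ID
    schema_builder=SchemaBuilder(
        sample_input={"inputs": "What is SageMaker?"},
        sample_output=[[0.1, 0.2, 0.3]],
    ),
)
model = builder.build()  # routes to _build_for_tei() for sentence-similarity
```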
diff --git a/src/sagemaker/serve/builder/tei_builder.py b/src/sagemaker/serve/builder/tei_builder.py
new file mode 100644
index 0000000000..6aba3c9da2
--- /dev/null
+++ b/src/sagemaker/serve/builder/tei_builder.py
@@ -0,0 +1,224 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Holds mixin logic to support deployment of Model ID"""
+from __future__ import absolute_import
+import logging
+from typing import Type
+from abc import ABC, abstractmethod
+
+from sagemaker import image_uris
+from sagemaker.model import Model
+from sagemaker.djl_inference.model import _get_model_config_properties_from_hf
+
+from sagemaker.huggingface import HuggingFaceModel
+from sagemaker.serve.utils.local_hardware import (
+ _get_nb_instance,
+)
+from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
+from sagemaker.serve.utils.predictors import TeiLocalModePredictor
+from sagemaker.serve.utils.types import ModelServer
+from sagemaker.serve.mode.function_pointers import Mode
+from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
+from sagemaker.base_predictor import PredictorBase
+
+logger = logging.getLogger(__name__)
+
+_CODE_FOLDER = "code"
+
+
+class TEI(ABC):
+ """TEI build logic for ModelBuilder()"""
+
+ def __init__(self):
+ self.model = None
+ self.serve_settings = None
+ self.sagemaker_session = None
+ self.model_path = None
+ self.dependencies = None
+ self.modes = None
+ self.mode = None
+ self.model_server = None
+ self.image_uri = None
+ self._is_custom_image_uri = False
+ self.image_config = None
+ self.vpc_config = None
+ self._original_deploy = None
+ self.hf_model_config = None
+ self._default_tensor_parallel_degree = None
+ self._default_data_type = None
+ self._default_max_tokens = None
+ self.pysdk_model = None
+ self.schema_builder = None
+ self.env_vars = None
+ self.nb_instance_type = None
+ self.ram_usage_model_load = None
+ self.secret_key = None
+ self.jumpstart = None
+ self.role_arn = None
+
+ @abstractmethod
+ def _prepare_for_mode(self):
+ """Placeholder docstring"""
+
+ @abstractmethod
+ def _get_client_translators(self):
+ """Placeholder docstring"""
+
+ def _set_to_tei(self):
+ """Placeholder docstring"""
+ if self.model_server != ModelServer.TEI:
+ messaging = (
+ "HuggingFace Model ID support on model server: "
+ f"{self.model_server} is not currently supported. "
+ f"Defaulting to {ModelServer.TEI}"
+ )
+ logger.warning(messaging)
+ self.model_server = ModelServer.TEI
+
+ def _create_tei_model(self, **kwargs) -> Type[Model]:
+ """Placeholder docstring"""
+ if self.nb_instance_type and "instance_type" not in kwargs:
+ kwargs.update({"instance_type": self.nb_instance_type})
+
+ if not self.image_uri:
+ self.image_uri = image_uris.retrieve(
+ "huggingface-tei",
+ image_scope="inference",
+ instance_type=kwargs.get("instance_type"),
+ region=self.sagemaker_session.boto_region_name,
+ )
+
+ pysdk_model = HuggingFaceModel(
+ image_uri=self.image_uri,
+ image_config=self.image_config,
+ vpc_config=self.vpc_config,
+ env=self.env_vars,
+ role=self.role_arn,
+ sagemaker_session=self.sagemaker_session,
+ )
+
+ logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
+
+ self._original_deploy = pysdk_model.deploy
+ pysdk_model.deploy = self._tei_model_builder_deploy_wrapper
+ return pysdk_model
+
+ @_capture_telemetry("tei.deploy")
+ def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
+ """Placeholder docstring"""
+ timeout = kwargs.get("model_data_download_timeout")
+ if timeout:
+ self.pysdk_model.env.update({"MODEL_LOADING_TIMEOUT": str(timeout)})
+
+ if "mode" in kwargs and kwargs.get("mode") != self.mode:
+ overwrite_mode = kwargs.get("mode")
+ # mode overwritten by customer during model.deploy()
+ logger.warning(
+ "Deploying in %s Mode, overriding existing configurations set for %s mode",
+ overwrite_mode,
+ self.mode,
+ )
+
+ if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
+ self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
+ elif overwrite_mode == Mode.LOCAL_CONTAINER:
+ self._prepare_for_mode()
+ self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
+ else:
+ raise ValueError("Mode %s is not supported!" % overwrite_mode)
+
+ serializer = self.schema_builder.input_serializer
+ deserializer = self.schema_builder._output_deserializer
+ if self.mode == Mode.LOCAL_CONTAINER:
+ timeout = kwargs.get("model_data_download_timeout")
+
+ predictor = TeiLocalModePredictor(
+ self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
+ )
+
+ self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
+ self.image_uri,
+ timeout if timeout else 1800,
+ None,
+ predictor,
+ self.pysdk_model.env,
+ jumpstart=False,
+ )
+
+ return predictor
+
+ if "mode" in kwargs:
+ del kwargs["mode"]
+ if "role" in kwargs:
+ self.pysdk_model.role = kwargs.get("role")
+ del kwargs["role"]
+
+ # set model_data to uncompressed s3 dict
+ self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
+ self.env_vars.update(env_vars)
+ self.pysdk_model.env.update(self.env_vars)
+
+ # if the weights have been cached via local container mode -> set to offline
+ if str(Mode.LOCAL_CONTAINER) in self.modes:
+ self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"})
+ else:
+ # if has not been built for local container we must use cache
+ # that hosting has write access to.
+ self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp"
+ self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp"
+
+ if "endpoint_logging" not in kwargs:
+ kwargs["endpoint_logging"] = True
+
+ if self.nb_instance_type and "instance_type" not in kwargs:
+ kwargs.update({"instance_type": self.nb_instance_type})
+ elif not self.nb_instance_type and "instance_type" not in kwargs:
+ raise ValueError(
+ "Instance type must be provided when deploying " "to SageMaker Endpoint mode."
+ )
+
+ if "initial_instance_count" not in kwargs:
+ kwargs.update({"initial_instance_count": 1})
+
+ predictor = self._original_deploy(*args, **kwargs)
+
+ predictor.serializer = serializer
+ predictor.deserializer = deserializer
+ return predictor
+
+ def _build_for_hf_tei(self):
+ """Placeholder docstring"""
+ self.nb_instance_type = _get_nb_instance()
+
+ _create_dir_structure(self.model_path)
+ if not hasattr(self, "pysdk_model"):
+ self.env_vars.update({"HF_MODEL_ID": self.model})
+ self.hf_model_config = _get_model_config_properties_from_hf(
+ self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
+ )
+
+ self.pysdk_model = self._create_tei_model()
+
+ if self.mode == Mode.LOCAL_CONTAINER:
+ self._prepare_for_mode()
+
+ return self.pysdk_model
+
+ def _build_for_tei(self):
+ """Placeholder docstring"""
+ self.secret_key = None
+
+ self._set_to_tei()
+
+ self.pysdk_model = self._build_for_hf_tei()
+ return self.pysdk_model
diff --git a/src/sagemaker/serve/builder/tf_serving_builder.py b/src/sagemaker/serve/builder/tf_serving_builder.py
new file mode 100644
index 0000000000..42c548f4e4
--- /dev/null
+++ b/src/sagemaker/serve/builder/tf_serving_builder.py
@@ -0,0 +1,129 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Holds mixin logic to support deployment of Model ID"""
+from __future__ import absolute_import
+import logging
+import os
+from pathlib import Path
+from abc import ABC, abstractmethod
+
+from sagemaker import Session
+from sagemaker.serve.detector.pickler import save_pkl
+from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving
+from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor
+
+logger = logging.getLogger(__name__)
+
+_TF_SERVING_MODEL_BUILDER_ENTRY_POINT = "inference.py"
+_CODE_FOLDER = "code"
+
+
+# pylint: disable=attribute-defined-outside-init, disable=E1101
+class TensorflowServing(ABC):
+ """TensorflowServing build logic for ModelBuilder()"""
+
+ def __init__(self):
+ self.model = None
+ self.serve_settings = None
+ self.sagemaker_session = None
+ self.model_path = None
+ self.dependencies = None
+ self.modes = None
+ self.mode = None
+ self.model_server = None
+ self.image_uri = None
+ self._is_custom_image_uri = False
+ self.image_config = None
+ self.vpc_config = None
+ self._original_deploy = None
+ self.secret_key = None
+ self.engine = None
+ self.pysdk_model = None
+ self.schema_builder = None
+ self.env_vars = None
+
+ @abstractmethod
+ def _prepare_for_mode(self):
+ """Prepare model artifacts based on mode."""
+
+ @abstractmethod
+ def _get_client_translators(self):
+ """Set up client marshaller based on schema builder."""
+
+ def _save_schema_builder(self):
+ """Save schema builder for tensorflow serving."""
+ if not os.path.exists(self.model_path):
+ os.makedirs(self.model_path)
+
+ code_path = Path(self.model_path).joinpath("code")
+ save_pkl(code_path, self.schema_builder)
+
+ def _get_tensorflow_predictor(
+ self, endpoint_name: str, sagemaker_session: Session
+ ) -> TensorFlowPredictor:
+ """Creates a TensorFlowPredictor object"""
+ serializer, deserializer = self._get_client_translators()
+
+ return TensorFlowPredictor(
+ endpoint_name=endpoint_name,
+ sagemaker_session=sagemaker_session,
+ serializer=serializer,
+ deserializer=deserializer,
+ )
+
+ def _validate_for_tensorflow_serving(self):
+ """Validate for tensorflow serving"""
+ if not getattr(self, "_is_mlflow_model", False):
+ raise ValueError("Tensorflow Serving is currently only supported for mlflow models.")
+
+ def _create_tensorflow_model(self):
+ """Creates a TensorFlow model object"""
+ self.pysdk_model = TensorFlowModel(
+ image_uri=self.image_uri,
+ image_config=self.image_config,
+ vpc_config=self.vpc_config,
+ model_data=self.s3_upload_path,
+ role=self.serve_settings.role_arn,
+ env=self.env_vars,
+ sagemaker_session=self.sagemaker_session,
+ predictor_cls=self._get_tensorflow_predictor,
+ )
+
+ self.pysdk_model.mode = self.mode
+ self.pysdk_model.modes = self.modes
+ self.pysdk_model.serve_settings = self.serve_settings
+
+ self._original_deploy = self.pysdk_model.deploy
+ self.pysdk_model.deploy = self._model_builder_deploy_wrapper
+ self._original_register = self.pysdk_model.register
+ self.pysdk_model.register = self._model_builder_register_wrapper
+ self.model_package = None
+ return self.pysdk_model
+
+ def _build_for_tensorflow_serving(self):
+ """Build the model for Tensorflow Serving"""
+ self._validate_for_tensorflow_serving()
+ self._save_schema_builder()
+
+ if not self.image_uri:
+ raise ValueError("image_uri is not set for tensorflow serving")
+
+ self.secret_key = prepare_for_tf_serving(
+ model_path=self.model_path,
+ shared_libs=self.shared_libs,
+ dependencies=self.dependencies,
+ )
+
+ self._prepare_for_mode()
+
+ return self._create_tensorflow_model()
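A hedged end-to-end sketch of the new TensorFlow Serving path for MLflow models; the S3 path and role ARN are hypothetical, and the flavor in the MLmodel file is assumed to be `tensorflow` or `keras` so the default model server resolves to `TENSORFLOW_SERVING`:

```python
import numpy as np

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder

builder = ModelBuilder(
    model_metadata={"MLFLOW_MODEL_PATH": "s3://my-bucket/mlflow-tf-model/"},  # hypothetical
    schema_builder=SchemaBuilder(
        sample_input=np.zeros((1, 28, 28)),
        sample_output=np.zeros((1, 10)),
    ),
    role_arn="arn:aws:iam::111122223333:role/ExampleRole",  # hypothetical
)
model = builder.build()  # keras/tensorflow flavors default to TensorFlow Serving
```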
diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py
index 3d84e314df..f84d8f868d 100644
--- a/src/sagemaker/serve/builder/transformers_builder.py
+++ b/src/sagemaker/serve/builder/transformers_builder.py
@@ -78,6 +78,25 @@ def _prepare_for_mode(self):
"""Abstract method"""
def _create_transformers_model(self) -> Type[Model]:
+ """Initializes HF model with or without image_uri"""
+ if self.image_uri is None:
+ pysdk_model = self._get_hf_metadata_create_model()
+ else:
+ pysdk_model = HuggingFaceModel(
+ image_uri=self.image_uri,
+ vpc_config=self.vpc_config,
+ env=self.env_vars,
+ role=self.role_arn,
+ sagemaker_session=self.sagemaker_session,
+ )
+
+ logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
+
+ self._original_deploy = pysdk_model.deploy
+ pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
+ return pysdk_model
+
+ def _get_hf_metadata_create_model(self) -> Type[Model]:
"""Initializes the model after fetching image
1. Get the metadata for deciding framework
@@ -132,19 +151,21 @@ def _create_transformers_model(self) -> Type[Model]:
vpc_config=self.vpc_config,
)
- if self.mode == Mode.LOCAL_CONTAINER:
+ if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER:
self.image_uri = pysdk_model.serving_image_uri(
self.sagemaker_session.boto_region_name, "local"
)
- else:
+ elif not self.image_uri:
self.image_uri = pysdk_model.serving_image_uri(
self.sagemaker_session.boto_region_name, self.instance_type
)
- logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
+ if pysdk_model is None or self.image_uri is None:
+ raise ValueError("PySDK model unable to be created, try overriding image_uri")
+
+ if not pysdk_model.image_uri:
+ pysdk_model.image_uri = self.image_uri
- self._original_deploy = pysdk_model.deploy
- pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
return pysdk_model
@_capture_telemetry("transformers.deploy")
@@ -251,13 +272,14 @@ def _set_instance(self, **kwargs):
if self.mode == Mode.SAGEMAKER_ENDPOINT:
if self.nb_instance_type and "instance_type" not in kwargs:
kwargs.update({"instance_type": self.nb_instance_type})
+ logger.info("Setting instance type to %s", self.nb_instance_type)
elif self.instance_type and "instance_type" not in kwargs:
kwargs.update({"instance_type": self.instance_type})
+ logger.info("Setting instance type to %s", self.instance_type)
else:
raise ValueError(
"Instance type must be provided when deploying to SageMaker Endpoint mode."
)
- logger.info("Setting instance type to %s", self.instance_type)
def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
"""Uses the hugging face json config to pick supported versions"""
diff --git a/src/sagemaker/serve/mode/local_container_mode.py b/src/sagemaker/serve/mode/local_container_mode.py
index 362a3804de..f040c61c1d 100644
--- a/src/sagemaker/serve/mode/local_container_mode.py
+++ b/src/sagemaker/serve/mode/local_container_mode.py
@@ -11,6 +11,7 @@
import docker
from sagemaker.base_predictor import PredictorBase
+from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.logging_agent import pull_logs
@@ -20,6 +21,7 @@
from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing
from sagemaker.serve.model_server.triton.server import LocalTritonServer
from sagemaker.serve.model_server.tgi.server import LocalTgiServing
+from sagemaker.serve.model_server.tei.server import LocalTeiServing
from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer
from sagemaker.session import Session
@@ -34,7 +36,12 @@
class LocalContainerMode(
- LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalMultiModelServer
+ LocalTorchServe,
+ LocalDJLServing,
+ LocalTritonServer,
+ LocalTgiServing,
+ LocalMultiModelServer,
+ LocalTensorflowServing,
):
"""A class that holds methods to deploy model to a container in local environment"""
@@ -63,6 +70,7 @@ def __init__(
self.container = None
self.secret_key = None
self._ping_container = None
+ self._invoke_serving = None
def load(self, model_path: str = None):
"""Placeholder docstring"""
@@ -141,6 +149,28 @@ def create_server(
env_vars=env_vars if env_vars else self.env_vars,
)
self._ping_container = self._multi_model_server_deep_ping
+ elif self.model_server == ModelServer.TENSORFLOW_SERVING:
+ self._start_tensorflow_serving(
+ client=self.client,
+ image=image,
+ model_path=model_path if model_path else self.model_path,
+ secret_key=secret_key,
+ env_vars=env_vars if env_vars else self.env_vars,
+ )
+ self._ping_container = self._tensorflow_serving_deep_ping
+ elif self.model_server == ModelServer.TEI:
+ tei_serving = LocalTeiServing()
+ tei_serving._start_tei_serving(
+ client=self.client,
+ image=image,
+ model_path=model_path if model_path else self.model_path,
+ secret_key=secret_key,
+ env_vars=env_vars if env_vars else self.env_vars,
+ )
+ tei_serving.schema_builder = self.schema_builder
+ self.container = tei_serving.container
+ self._ping_container = tei_serving._tei_deep_ping
+ self._invoke_serving = tei_serving._invoke_tei_serving
# allow some time for container to be ready
time.sleep(10)
diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
index 0fdc425b31..b8f1d0529b 100644
--- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
+++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
@@ -6,6 +6,8 @@
import logging
from typing import Type
+from sagemaker.serve.model_server.tei.server import SageMakerTeiServing
+from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing
from sagemaker.session import Session
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -24,6 +26,7 @@ class SageMakerEndpointMode(
SageMakerDjlServing,
SageMakerTgiServing,
SageMakerMultiModelServer,
+ SageMakerTensorflowServing,
):
"""Holds the required method to deploy a model to a SageMaker Endpoint"""
@@ -35,6 +38,8 @@ def __init__(self, inference_spec: Type[InferenceSpec], model_server: ModelServe
self.inference_spec = inference_spec
self.model_server = model_server
+ self._tei_serving = SageMakerTeiServing()
+
def load(self, model_path: str):
"""Placeholder docstring"""
path = Path(model_path)
@@ -64,8 +69,9 @@ def prepare(
+ "session to be created or supply `sagemaker_session` into @serve.invoke."
) from e
+ upload_artifacts = None
if self.model_server == ModelServer.TORCHSERVE:
- return self._upload_torchserve_artifacts(
+ upload_artifacts = self._upload_torchserve_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
secret_key=secret_key,
@@ -74,7 +80,7 @@ def prepare(
)
if self.model_server == ModelServer.TRITON:
- return self._upload_triton_artifacts(
+ upload_artifacts = self._upload_triton_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
secret_key=secret_key,
@@ -83,7 +89,7 @@ def prepare(
)
if self.model_server == ModelServer.DJL_SERVING:
- return self._upload_djl_artifacts(
+ upload_artifacts = self._upload_djl_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
@@ -91,7 +97,7 @@ def prepare(
)
if self.model_server == ModelServer.TGI:
- return self._upload_tgi_artifacts(
+ upload_artifacts = self._upload_tgi_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
s3_model_data_url=s3_model_data_url,
@@ -100,11 +106,31 @@ def prepare(
)
if self.model_server == ModelServer.MMS:
- return self._upload_server_artifacts(
+ upload_artifacts = self._upload_server_artifacts(
+ model_path=model_path,
+ sagemaker_session=sagemaker_session,
+ s3_model_data_url=s3_model_data_url,
+ image=image,
+ )
+
+ if self.model_server == ModelServer.TENSORFLOW_SERVING:
+ upload_artifacts = self._upload_tensorflow_serving_artifacts(
model_path=model_path,
sagemaker_session=sagemaker_session,
+ secret_key=secret_key,
s3_model_data_url=s3_model_data_url,
image=image,
)
+ if self.model_server == ModelServer.TEI:
+ upload_artifacts = self._tei_serving._upload_tei_artifacts(
+ model_path=model_path,
+ sagemaker_session=sagemaker_session,
+ s3_model_data_url=s3_model_data_url,
+ image=image,
+ )
+
+ if upload_artifacts:
+ return upload_artifacts
+
raise ValueError("%s model server is not supported" % self.model_server)
diff --git a/src/sagemaker/serve/model_format/mlflow/constants.py b/src/sagemaker/serve/model_format/mlflow/constants.py
index 00ef76170c..28a3cbdc8d 100644
--- a/src/sagemaker/serve/model_format/mlflow/constants.py
+++ b/src/sagemaker/serve/model_format/mlflow/constants.py
@@ -19,6 +19,12 @@
"py39": "1.13.1",
"py310": "2.2.0",
}
+MODEL_PACKAGE_ARN_REGEX = (
+ r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$"
+)
+MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
+MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+(/[0-9]+)*$"
+S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$"
MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH"
MLFLOW_METADATA_FILE = "MLmodel"
MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt"
@@ -34,8 +40,12 @@
"spark": "pyspark",
"onnx": "onnxruntime",
}
-FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = [ # will extend to keras and tf
- "sklearn",
- "pytorch",
- "xgboost",
-]
+TENSORFLOW_SAVED_MODEL_NAME = "saved_model.pb"
+FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = {
+ "sklearn": "sklearn",
+ "pytorch": "pytorch",
+ "xgboost": "xgboost",
+ "tensorflow": "tensorflow",
+ "keras": "tensorflow",
+}
+FLAVORS_DEFAULT_WITH_TF_SERVING = ["keras", "tensorflow"]
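The new regexes can be sanity-checked directly; a quick sketch with illustrative values:

```python
import re

MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9]+)*$"
S3_PATH_REGEX = r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$"

assert re.match(MLFLOW_RUN_ID_REGEX, "runs:/a1b2c3d4e5/model")  # run id hypothetical
assert re.match(S3_PATH_REGEX, "s3://my-bucket/mlflow/model")   # bucket hypothetical
assert not re.match(MLFLOW_RUN_ID_REGEX, "models:/MyModel/1")   # registry paths differ
```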
diff --git a/src/sagemaker/serve/model_format/mlflow/utils.py b/src/sagemaker/serve/model_format/mlflow/utils.py
index c9a8093a79..c92a6a8a27 100644
--- a/src/sagemaker/serve/model_format/mlflow/utils.py
+++ b/src/sagemaker/serve/model_format/mlflow/utils.py
@@ -13,7 +13,8 @@
"""Holds the util functions used for MLflow model format"""
from __future__ import absolute_import
-from typing import Optional, Dict, Any
+from pathlib import Path
+from typing import Optional, Dict, Any, Union
import yaml
import logging
import shutil
@@ -30,6 +31,8 @@
DEFAULT_PYTORCH_VERSION,
MLFLOW_METADATA_FILE,
MLFLOW_PIP_DEPENDENCY_FILE,
+ FLAVORS_DEFAULT_WITH_TF_SERVING,
+ TENSORFLOW_SAVED_MODEL_NAME,
)
logger = logging.getLogger(__name__)
@@ -44,7 +47,8 @@ def _get_default_model_server_for_mlflow(deployment_flavor: str) -> ModelServer:
Returns:
str: The model server chosen for given model flavor.
"""
- # TODO: implement real logic here based on mlflow flavor
+ if deployment_flavor in FLAVORS_DEFAULT_WITH_TF_SERVING:
+ return ModelServer.TENSORFLOW_SERVING
return ModelServer.TORCHSERVE
@@ -274,7 +278,7 @@ def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> Non
os.makedirs(local_file_dir, exist_ok=True)
# Download the file
- print(f"Downloading {key} to {local_file_path}")
+ logger.info(f"Downloading {key} to {local_file_path}")
s3.download_file(s3_bucket, key, local_file_path)
@@ -344,15 +348,25 @@ def _select_container_for_mlflow_model(
f"specific DLC support. Defaulting to generic image..."
)
return _get_default_image_for_mlflow(python_version, region, instance_type)
- framework_version = _get_framework_version_from_requirements(
- deployment_flavor, requirement_path
- )
+
+ framework_to_use = FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT.get(deployment_flavor)
+ framework_version = _get_framework_version_from_requirements(framework_to_use, requirement_path)
logger.info("Auto-detected deployment flavor is %s", deployment_flavor)
+ logger.info("Auto-detected framework to use is %s", framework_to_use)
logger.info("Auto-detected framework version is %s", framework_version)
+ if framework_version is None:
+ raise ValueError(
+ (
+ "Unable to auto detect framework version. Please provide framework %s as part of the "
+ "requirements.txt file for deployment flavor %s"
+ )
+ % (framework_to_use, deployment_flavor)
+ )
+
casted_versions = (
- _cast_to_compatible_version(deployment_flavor, framework_version)
+ _cast_to_compatible_version(framework_to_use, framework_version)
if framework_version
else (None,)
)
@@ -361,7 +375,7 @@ def _select_container_for_mlflow_model(
for casted_version in casted_versions:
try:
image_uri = image_uris.retrieve(
- framework=deployment_flavor,
+ framework=framework_to_use,
region=region,
version=casted_version,
image_scope="inference",
@@ -392,17 +406,60 @@ def _select_container_for_mlflow_model(
)
-def _validate_input_for_mlflow(model_server: ModelServer) -> None:
+def _validate_input_for_mlflow(model_server: ModelServer, deployment_flavor: str) -> None:
"""Validates arguments provided with mlflow models.
Args:
- model_server (ModelServer): Model server used for orchestrating mlflow model.
+ - deployment_flavor (str): The flavor mlflow model will be deployed with.
Raises:
- ValueError: If model server is not torchserve.
"""
- if model_server != ModelServer.TORCHSERVE:
+ if model_server != ModelServer.TORCHSERVE and model_server != ModelServer.TENSORFLOW_SERVING:
raise ValueError(
f"{model_server} is currently not supported for MLflow Model. "
f"Please choose another model server."
)
+ if (
+ model_server == ModelServer.TENSORFLOW_SERVING
+ and deployment_flavor not in FLAVORS_DEFAULT_WITH_TF_SERVING
+ ):
+ raise ValueError(
+ "Tensorflow Serving is currently only supported for the following "
+ "deployment flavors: {}".format(FLAVORS_DEFAULT_WITH_TF_SERVING)
+ )
+
+
+def _get_saved_model_path_for_tensorflow_and_keras_flavor(model_path: str) -> Optional[str]:
+ """Recursively searches for tensorflow saved model.
+
+ Args:
+ model_path (str): The root directory to start the search from.
+
+ Returns:
+ Optional[str]: The absolute path to the directory containing 'saved_model.pb'.
+ """
+ for dirpath, dirnames, filenames in os.walk(model_path):
+ if TENSORFLOW_SAVED_MODEL_NAME in filenames:
+ return os.path.abspath(dirpath)
+
+ return None
+
+
+def _move_contents(src_dir: Union[str, Path], dest_dir: Union[str, Path]) -> None:
+ """Moves all contents of a source directory to a specified destination directory.
+
+ Args:
+ src_dir (Union[str, Path]): The path to the source directory.
+ dest_dir (Union[str, Path]): The path to the destination directory.
+
+ """
+ _src_dir = Path(os.path.normpath(src_dir))
+ _dest_dir = Path(os.path.normpath(dest_dir))
+
+ _dest_dir.mkdir(parents=True, exist_ok=True)
+
+ for item in _src_dir.iterdir():
+ _dest_path = _dest_dir / item.name
+ shutil.move(str(item), str(_dest_path))
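A minimal sketch of how the two new helpers compose when staging a SavedModel for TensorFlow Serving; the directory layout and paths are illustrative assumptions:

```python
from sagemaker.serve.model_format.mlflow.utils import (
    _get_saved_model_path_for_tensorflow_and_keras_flavor,
    _move_contents,
)

# Walk the downloaded MLflow artifacts for a saved_model.pb, then relocate
# its directory into the versioned layout TensorFlow Serving expects.
saved_dir = _get_saved_model_path_for_tensorflow_and_keras_flavor("/tmp/mlflow_model")
if saved_dir:
    _move_contents(src_dir=saved_dir, dest_dir="/tmp/serving_model/1")
```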
diff --git a/src/sagemaker/serve/model_server/multi_model_server/prepare.py b/src/sagemaker/serve/model_server/multi_model_server/prepare.py
index 7a16cc0a43..7059d9026d 100644
--- a/src/sagemaker/serve/model_server/multi_model_server/prepare.py
+++ b/src/sagemaker/serve/model_server/multi_model_server/prepare.py
@@ -15,7 +15,9 @@
from __future__ import absolute_import
import logging
from pathlib import Path
+from typing import List
+from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts
from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
logger = logging.getLogger(__name__)
@@ -36,3 +38,28 @@ def _create_dir_structure(model_path: str) -> tuple:
_check_docker_disk_usage()
return model_path, code_dir
+
+
+def prepare_mms_js_resources(
+ model_path: str,
+ js_id: str,
+ shared_libs: List[str] = None,
+ dependencies: str = None,
+ model_data: str = None,
+) -> tuple:
+ """Prepare serving when a JumpStart model id is given
+
+ Args:
+ model_path (str) : Argument
+ js_id (str): Argument
+ shared_libs (List[]) : Argument
+ dependencies (str) : Argument
+ model_data (str) : Argument
+
+ Returns:
+ ( str ) :
+
+ """
+ model_path, code_dir = _create_dir_structure(model_path)
+
+ return _copy_jumpstart_artifacts(model_data, js_id, code_dir)
diff --git a/src/sagemaker/serve/model_server/tei/__init__.py b/src/sagemaker/serve/model_server/tei/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/sagemaker/serve/model_server/tei/server.py b/src/sagemaker/serve/model_server/tei/server.py
new file mode 100644
index 0000000000..67fca0e847
--- /dev/null
+++ b/src/sagemaker/serve/model_server/tei/server.py
@@ -0,0 +1,160 @@
+"""Module for Local TEI Serving"""
+
+from __future__ import absolute_import
+
+import requests
+import logging
+from pathlib import Path
+from docker.types import DeviceRequest
+from sagemaker import Session, fw_utils
+from sagemaker.serve.utils.exceptions import LocalModelInvocationException
+from sagemaker.base_predictor import PredictorBase
+from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join
+from sagemaker.s3 import S3Uploader
+from sagemaker.local.utils import get_docker_host
+
+
+MODE_DIR_BINDING = "/opt/ml/model/"
+_SHM_SIZE = "2G"
+_DEFAULT_ENV_VARS = {
+ "TRANSFORMERS_CACHE": "/opt/ml/model/",
+ "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
+}
+
+logger = logging.getLogger(__name__)
+
+
+class LocalTeiServing:
+ """LocalTeiServing class"""
+
+ def _start_tei_serving(
+ self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
+ ):
+ """Starts a local tei serving container.
+
+ Args:
+ client: Docker client
+ image: Image to use
+ model_path: Path to the model
+ secret_key: Secret key to use for authentication
+ env_vars: Environment variables to set
+ """
+ if env_vars and secret_key:
+ env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key
+
+ self.container = client.containers.run(
+ image,
+ shm_size=_SHM_SIZE,
+ device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
+ network_mode="host",
+ detach=True,
+ auto_remove=True,
+ volumes={
+ Path(model_path).joinpath("code"): {
+ "bind": MODE_DIR_BINDING,
+ "mode": "rw",
+ },
+ },
+ environment=_update_env_vars(env_vars),
+ )
+
+ def _invoke_tei_serving(self, request: object, content_type: str, accept: str):
+ """Invokes a local tei serving container.
+
+ Args:
+ request: Request to send
+ content_type: Content type to use
+ accept: Accept to use
+ """
+ try:
+ response = requests.post(
+ f"http://{get_docker_host()}:8080/invocations",
+ data=request,
+ headers={"Content-Type": content_type, "Accept": accept},
+ timeout=600,
+ )
+ response.raise_for_status()
+ return response.content
+ except Exception as e:
+ raise Exception("Unable to send request to the local container server") from e
+
+ def _tei_deep_ping(self, predictor: PredictorBase):
+ """Checks if the local tei serving container is up and running.
+
+ If the container is not up and running, it will raise an exception.
+ """
+ response = None
+ try:
+ response = predictor.predict(self.schema_builder.sample_input)
+ return (True, response)
+ # pylint: disable=broad-except
+ except Exception as e:
+ if "422 Client Error: Unprocessable Entity for url" in str(e):
+ raise LocalModelInvocationException(str(e))
+ return (False, response)
+
+ return (True, response)
+
+
+class SageMakerTeiServing:
+ """SageMakerTeiServing class"""
+
+ def _upload_tei_artifacts(
+ self,
+ model_path: str,
+ sagemaker_session: Session,
+ s3_model_data_url: str = None,
+ image: str = None,
+ env_vars: dict = None,
+ ):
+ """Uploads the model artifacts to S3.
+
+ Args:
+ model_path: Path to the model
+ sagemaker_session: SageMaker session
+ s3_model_data_url: S3 model data URL
+ image: Image to use
+ env_vars: Environment variables to set
+ """
+ if s3_model_data_url:
+ bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
+ else:
+ bucket, key_prefix = None, None
+
+ code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)
+
+ bucket, code_key_prefix = determine_bucket_and_prefix(
+ bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
+ )
+
+ code_dir = Path(model_path).joinpath("code")
+
+ s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code")
+
+ logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location)
+
+ model_data_url = S3Uploader.upload(
+ str(code_dir),
+ s3_location,
+ None,
+ sagemaker_session,
+ )
+
+ model_data = {
+ "S3DataSource": {
+ "CompressionType": "None",
+ "S3DataType": "S3Prefix",
+ "S3Uri": model_data_url + "/",
+ }
+ }
+
+ return (model_data, _update_env_vars(env_vars))
+
+
+def _update_env_vars(env_vars: dict) -> dict:
+ """Placeholder docstring"""
+ updated_env_vars = {}
+ updated_env_vars.update(_DEFAULT_ENV_VARS)
+ if env_vars:
+ updated_env_vars.update(env_vars)
+ return updated_env_vars
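The merge order above means the TEI cache defaults yield to caller-supplied values; a sketch (env var values hypothetical):

```python
merged = _update_env_vars({"HF_MODEL_ID": "my-model", "TRANSFORMERS_CACHE": "/tmp"})
# Caller values win over the defaults:
assert merged == {
    "TRANSFORMERS_CACHE": "/tmp",
    "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
    "HF_MODEL_ID": "my-model",
}
```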
diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/__init__.py b/src/sagemaker/serve/model_server/tensorflow_serving/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/inference.py b/src/sagemaker/serve/model_server/tensorflow_serving/inference.py
new file mode 100644
index 0000000000..928278e3c6
--- /dev/null
+++ b/src/sagemaker/serve/model_server/tensorflow_serving/inference.py
@@ -0,0 +1,147 @@
+"""This module is for SageMaker inference.py."""
+
+from __future__ import absolute_import
+import os
+import io
+import json
+import cloudpickle
+import shutil
+import platform
+from pathlib import Path
+from sagemaker.serve.validations.check_integrity import perform_integrity_check
+import logging
+
+logger = logging.getLogger(__name__)
+
+schema_builder = None
+SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs")
+SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl")
+METADATA_PATH = Path(__file__).parent.joinpath("metadata.json")
+
+
+def input_handler(data, context):
+ """Pre-process request input before it is sent to TensorFlow Serving REST API
+
+ Args:
+ data (obj): the request data, in format of dict or string
+ context (Context): an object containing request and configuration details
+ Returns:
+ (dict): a JSON-serializable dict that contains request body and headers
+ """
+ read_data = data.read()
+ deserialized_data = None
+ try:
+ if hasattr(schema_builder, "custom_input_translator"):
+ deserialized_data = schema_builder.custom_input_translator.deserialize(
+ io.BytesIO(read_data), context.request_content_type
+ )
+ else:
+ deserialized_data = schema_builder.input_deserializer.deserialize(
+ io.BytesIO(read_data), context.request_content_type
+ )
+ except Exception as e:
+ logger.error("Encountered error: %s in deserialize_request." % e)
+ raise Exception("Encountered error in deserialize_request.") from e
+
+ try:
+ return json.dumps({"instances": _convert_for_serialization(deserialized_data)})
+ except Exception as e:
+ logger.error(
+ "Encountered error: %s in deserialize_request. "
+ "Deserialized data is not json serializable. " % e
+ )
+ raise Exception("Encountered error in deserialize_request.") from e
+
+
+def output_handler(data, context):
+ """Post-process TensorFlow Serving output before it is returned to the client.
+
+ Args:
+ data (obj): the TensorFlow serving response
+ context (Context): an object containing request and configuration details
+ Returns:
+ (bytes, string): data to return to client, response content type
+ """
+ if data.status_code != 200:
+ raise ValueError(data.content.decode("utf-8"))
+
+ response_content_type = context.accept_header
+ prediction = data.content
+ try:
+ prediction_dict = json.loads(prediction.decode("utf-8"))
+ if hasattr(schema_builder, "custom_output_translator"):
+ return (
+ schema_builder.custom_output_translator.serialize(
+ prediction_dict["predictions"], response_content_type
+ ),
+ response_content_type,
+ )
+ else:
+ return schema_builder.output_serializer.serialize(prediction), response_content_type
+ except Exception as e:
+ logger.error("Encountered error: %s in serialize_response." % e)
+ raise Exception("Encountered error in serialize_response.") from e
+
+
+def _run_preflight_diagnostics():
+ _py_vs_parity_check()
+ _pickle_file_integrity_check()
+
+
+def _py_vs_parity_check():
+ container_py_vs = platform.python_version()
+ local_py_vs = os.getenv("LOCAL_PYTHON")
+
+ if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]:
+ logger.warning(
+ f"The local python version {local_py_vs} differs from the python version "
+ f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior"
+ )
+
+
+def _pickle_file_integrity_check():
+ with open(SERVE_PATH, "rb") as f:
+ buffer = f.read()
+
+ perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH)
+
+
+def _set_up_schema_builder():
+ """Sets up the schema_builder object."""
+ global schema_builder
+ with open(SERVE_PATH, "rb") as f:
+ schema_builder = cloudpickle.load(f)
+
+
+def _set_up_shared_libs():
+ """Sets up the shared libs path."""
+ if SHARED_LIBS_DIR.exists():
+ # before importing, place dynamic linked libraries in shared lib path
+ shutil.copytree(SHARED_LIBS_DIR, "/lib", dirs_exist_ok=True)
+
+
+def _convert_for_serialization(deserialized_data):
+ """Attempt to convert non-serializable objects to a serializable form.
+
+ Args:
+ deserialized_data: The object to convert.
+
+ Returns:
+ The converted object if it was not originally serializable, otherwise the original object.
+ """
+ import numpy as np
+ import pandas as pd
+
+ if isinstance(deserialized_data, np.ndarray):
+ return deserialized_data.tolist()
+ elif isinstance(deserialized_data, pd.DataFrame):
+ return deserialized_data.to_dict(orient="list")
+ elif isinstance(deserialized_data, pd.Series):
+ return deserialized_data.tolist()
+ return deserialized_data
+
+
+# on import, execute
+_run_preflight_diagnostics()
+_set_up_schema_builder()
+_set_up_shared_libs()
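Since the module runs its preflight checks on import, the conversion helper is easiest to illustrate standalone; this mirror of `_convert_for_serialization` shows the shapes `json.dumps` receives in `input_handler` (values illustrative):

```python
import numpy as np
import pandas as pd

def convert_for_serialization(data):
    """Standalone mirror of the module's _convert_for_serialization."""
    if isinstance(data, np.ndarray):
        return data.tolist()
    if isinstance(data, pd.DataFrame):
        return data.to_dict(orient="list")
    if isinstance(data, pd.Series):
        return data.tolist()
    return data

assert convert_for_serialization(np.array([[1, 2], [3, 4]])) == [[1, 2], [3, 4]]
assert convert_for_serialization(pd.Series([1, 2])) == [1, 2]
```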
diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py b/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py
new file mode 100644
index 0000000000..e9aa4aafff
--- /dev/null
+++ b/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py
@@ -0,0 +1,67 @@
+"""Module for artifacts preparation for tensorflow_serving"""
+
+from __future__ import absolute_import
+from pathlib import Path
+import shutil
+from typing import List, Dict, Any
+
+from sagemaker.serve.model_format.mlflow.utils import (
+ _get_saved_model_path_for_tensorflow_and_keras_flavor,
+ _move_contents,
+)
+from sagemaker.serve.detector.dependency_manager import capture_dependencies
+from sagemaker.serve.validations.check_integrity import (
+ generate_secret_key,
+ compute_hash,
+)
+from sagemaker.remote_function.core.serialization import _MetaData
+
+
+def prepare_for_tf_serving(
+ model_path: str,
+ shared_libs: List[str],
+ dependencies: Dict[str, Any],
+) -> str:
+ """Prepares the model for serving.
+
+ Args:
+ model_path (str): Path to the model directory.
+ shared_libs (List[str]): List of shared libraries.
+ dependencies (Dict[str, Any]): Dictionary of dependencies.
+
+ Returns:
+ str: Secret key.
+ """
+
+ _model_path = Path(model_path)
+ if not _model_path.exists():
+ _model_path.mkdir()
+ elif not _model_path.is_dir():
+ raise Exception("model_dir is not a valid directory")
+
+ code_dir = _model_path.joinpath("code")
+ code_dir.mkdir(exist_ok=True)
+ shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)
+
+ shared_libs_dir = _model_path.joinpath("shared_libs")
+ shared_libs_dir.mkdir(exist_ok=True)
+ for shared_lib in shared_libs:
+ shutil.copy2(Path(shared_lib), shared_libs_dir)
+
+ capture_dependencies(dependencies=dependencies, work_dir=code_dir)
+
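+    # TensorFlow Serving loads SavedModel bundles from numeric version
+    # directories, so stage the MLflow-exported SavedModel under "1".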
+ saved_model_bundle_dir = _model_path.joinpath("1")
+ saved_model_bundle_dir.mkdir(exist_ok=True)
+ mlflow_saved_model_dir = _get_saved_model_path_for_tensorflow_and_keras_flavor(model_path)
+ if not mlflow_saved_model_dir:
+ raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.")
+ _move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir)
+
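+    # Sign the pickled artifacts with a fresh secret key and record the hash in
+    # metadata.json; the serving container re-verifies it at startup.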
+ secret_key = generate_secret_key()
+ with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
+ buffer = f.read()
+ hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
+ with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
+ metadata.write(_MetaData(hash_value).to_json())
+
+ return secret_key
diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/server.py b/src/sagemaker/serve/model_server/tensorflow_serving/server.py
new file mode 100644
index 0000000000..2392287c61
--- /dev/null
+++ b/src/sagemaker/serve/model_server/tensorflow_serving/server.py
@@ -0,0 +1,139 @@
+"""Module for Local Tensorflow Server"""
+
+from __future__ import absolute_import
+
+import requests
+import logging
+import platform
+from pathlib import Path
+from sagemaker.base_predictor import PredictorBase
+from sagemaker.session import Session
+from sagemaker.serve.utils.exceptions import LocalModelInvocationException
+from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url
+from sagemaker import fw_utils
+from sagemaker.serve.utils.uploader import upload
+from sagemaker.local.utils import get_docker_host
+
+logger = logging.getLogger(__name__)
+
+
+class LocalTensorflowServing:
+ """LocalTensorflowServing class."""
+
+ def _start_tensorflow_serving(
+ self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
+ ):
+ """Starts a local tensorflow serving container.
+
+ Args:
+ client: Docker client
+ image: Image to use
+ model_path: Path to the model
+ secret_key: Secret key to use for authentication
+ env_vars: Environment variables to set
+ """
+ self.container = client.containers.run(
+ image,
+ "serve",
+ detach=True,
+ auto_remove=True,
+ network_mode="host",
+ volumes={
+ Path(model_path): {
+ "bind": "/opt/ml/model",
+ "mode": "rw",
+ },
+ },
+ environment={
+ "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+ "LOCAL_PYTHON": platform.python_version(),
+ **env_vars,
+ },
+ )
+
+ def _invoke_tensorflow_serving(self, request: object, content_type: str, accept: str):
+ """Invokes a local tensorflow serving container.
+
+ Args:
+ request: Request to send
+ content_type: Content type to use
+ accept: Accept to use
+ """
+ try:
+ response = requests.post(
+ f"http://{get_docker_host()}:8080/invocations",
+ data=request,
+ headers={"Content-Type": content_type, "Accept": accept},
+ timeout=60, # this is what SageMaker Hosting uses as timeout
+ )
+ response.raise_for_status()
+ return response.content
+ except Exception as e:
+ raise Exception("Unable to send request to the local container server") from e
+
+ def _tensorflow_serving_deep_ping(self, predictor: PredictorBase):
+ """Checks if the local tensorflow serving container is up and running.
+
+ If the container is not up and running, it will raise an exception.
+ """
+ response = None
+ try:
+ response = predictor.predict(self.schema_builder.sample_input)
+ return (True, response)
+ # pylint: disable=broad-except
+ except Exception as e:
+ if "422 Client Error: Unprocessable Entity for url" in str(e):
+ raise LocalModelInvocationException(str(e))
+ return (False, response)
+
+
+class SageMakerTensorflowServing:
+ """SageMakerTensorflowServing class."""
+
+ def _upload_tensorflow_serving_artifacts(
+ self,
+ model_path: str,
+ sagemaker_session: Session,
+ secret_key: str,
+ s3_model_data_url: str = None,
+ image: str = None,
+ ):
+ """Uploads the model artifacts to S3.
+
+ Args:
+ model_path: Path to the model
+ sagemaker_session: SageMaker session
+ secret_key: Secret key to use for authentication
+ s3_model_data_url: S3 model data URL
+ image: Image to use
+ """
+ if s3_model_data_url:
+ bucket, key_prefix = parse_s3_url(url=s3_model_data_url)
+ else:
+ bucket, key_prefix = None, None
+
+ code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image)
+
+ bucket, code_key_prefix = determine_bucket_and_prefix(
+ bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session
+ )
+
+ logger.debug(
+ "Uploading the model resources to bucket=%s, key_prefix=%s.", bucket, code_key_prefix
+ )
+ s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix)
+ logger.debug("Model resources uploaded to: %s", s3_upload_path)
+
+ env_vars = {
+ "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
+ "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
+ "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+ "LOCAL_PYTHON": platform.python_version(),
+ }
+ return s3_upload_path, env_vars
diff --git a/src/sagemaker/serve/utils/lineage_constants.py b/src/sagemaker/serve/utils/lineage_constants.py
new file mode 100644
index 0000000000..51be20739f
--- /dev/null
+++ b/src/sagemaker/serve/utils/lineage_constants.py
@@ -0,0 +1,28 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Holds constants used for lineage support"""
+from __future__ import absolute_import
+
+
+LINEAGE_POLLER_INTERVAL_SECS = 15
+LINEAGE_POLLER_MAX_TIMEOUT_SECS = 120
+MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE = "ModelBuilderInputModelData"
+MLFLOW_S3_PATH = "S3"
+MLFLOW_MODEL_PACKAGE_PATH = "ModelPackage"
+MLFLOW_RUN_ID = "MLflowRunId"
+MLFLOW_LOCAL_PATH = "Local"
+MLFLOW_REGISTRY_PATH = "MLflowRegistry"
+ERROR = "Error"
+CODE = "Code"
+CONTRIBUTED_TO = "ContributedTo"
+VALIDATION_EXCEPTION = "ValidationException"
diff --git a/src/sagemaker/serve/utils/lineage_utils.py b/src/sagemaker/serve/utils/lineage_utils.py
new file mode 100644
index 0000000000..3435e138c9
--- /dev/null
+++ b/src/sagemaker/serve/utils/lineage_utils.py
@@ -0,0 +1,277 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Holds the util functions used for lineage tracking"""
+from __future__ import absolute_import
+
+import os
+import time
+import re
+import logging
+from typing import Optional, Union
+
+from botocore.exceptions import ClientError
+
+from sagemaker import Session
+from sagemaker.lineage._api_types import ArtifactSummary
+from sagemaker.lineage.artifact import Artifact
+from sagemaker.lineage.association import Association
+from sagemaker.lineage.query import LineageSourceEnum
+from sagemaker.serve.model_format.mlflow.constants import (
+ MLFLOW_RUN_ID_REGEX,
+ MODEL_PACKAGE_ARN_REGEX,
+ S3_PATH_REGEX,
+ MLFLOW_REGISTRY_PATH_REGEX,
+)
+from sagemaker.serve.utils.lineage_constants import (
+ LINEAGE_POLLER_MAX_TIMEOUT_SECS,
+ LINEAGE_POLLER_INTERVAL_SECS,
+ MLFLOW_S3_PATH,
+ MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+ MLFLOW_LOCAL_PATH,
+ MLFLOW_MODEL_PACKAGE_PATH,
+ MLFLOW_RUN_ID,
+ MLFLOW_REGISTRY_PATH,
+ CONTRIBUTED_TO,
+ ERROR,
+ CODE,
+ VALIDATION_EXCEPTION,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _load_artifact_by_source_uri(
+ source_uri: str, artifact_type: str, sagemaker_session: Session
+) -> Optional[ArtifactSummary]:
+ """Load lineage artifact by source uri
+
+ Arguments:
+        source_uri (str): The S3 URI used for uploading the transformed model artifacts.
+ artifact_type (str): The type of the lineage artifact.
+ sagemaker_session (Session): Session object which manages interactions
+ with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+ function creates one using the default AWS configuration chain.
+
+ Returns:
+ ArtifactSummary: The Artifact Summary for the provided S3 URI.
+ """
+ artifacts = Artifact.list(source_uri=source_uri, sagemaker_session=sagemaker_session)
+ for artifact_summary in artifacts:
+ if artifact_summary.artifact_type == artifact_type:
+ return artifact_summary
+ return None
+
+
+def _poll_lineage_artifact(
+ s3_uri: str, artifact_type: str, sagemaker_session: Session
+) -> Optional[ArtifactSummary]:
+ """Polls lineage artifacts by s3 path.
+
+ Arguments:
+ s3_uri (str): The S3 URI to check for artifacts.
+ artifact_type (str): The type of the lineage artifact.
+ sagemaker_session (Session): Session object which manages interactions
+ with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+ function creates one using the default AWS configuration chain.
+
+ Returns:
+ Optional[ArtifactSummary]: The artifact summary if found, otherwise None.
+ """
+ logger.info("Polling lineage artifact for model data in %s", s3_uri)
+ start_time = time.time()
+ while time.time() - start_time < LINEAGE_POLLER_MAX_TIMEOUT_SECS:
+ result = _load_artifact_by_source_uri(s3_uri, artifact_type, sagemaker_session)
+ if result is not None:
+ return result
+        time.sleep(LINEAGE_POLLER_INTERVAL_SECS)
+    return None
+
+
+def _get_mlflow_model_path_type(mlflow_model_path: str) -> str:
+ """Identify mlflow model path type.
+
+ Args:
+ mlflow_model_path (str): The string to be identified.
+
+ Returns:
+ str: Description of what the input string is identified as.
+ """
+    mlflow_run_id_pattern = MLFLOW_RUN_ID_REGEX
+ mlflow_registry_id_pattern = MLFLOW_REGISTRY_PATH_REGEX
+ sagemaker_arn_pattern = MODEL_PACKAGE_ARN_REGEX
+ s3_pattern = S3_PATH_REGEX
+
+    if re.match(mlflow_run_id_pattern, mlflow_model_path):
+ return MLFLOW_RUN_ID
+ if re.match(mlflow_registry_id_pattern, mlflow_model_path):
+ return MLFLOW_REGISTRY_PATH
+ if re.match(sagemaker_arn_pattern, mlflow_model_path):
+ return MLFLOW_MODEL_PACKAGE_PATH
+ if re.match(s3_pattern, mlflow_model_path):
+ return MLFLOW_S3_PATH
+ if os.path.exists(mlflow_model_path):
+ return MLFLOW_LOCAL_PATH
+
+ raise ValueError(f"Invalid MLflow model path: {mlflow_model_path}")
+
+
+def _create_mlflow_model_path_lineage_artifact(
+ mlflow_model_path: str,
+ sagemaker_session: Session,
+) -> Optional[Artifact]:
+ """Creates a lineage artifact for the given MLflow model path.
+
+ Args:
+ mlflow_model_path (str): The path to the MLflow model.
+ sagemaker_session (Session): The SageMaker session object.
+
+ Returns:
+ Optional[Artifact]: The created lineage artifact, or None if an error occurred.
+ """
+ _artifact_name = _get_mlflow_model_path_type(mlflow_model_path)
+ properties = dict(
+ model_builder_input_model_data_type=_artifact_name,
+ )
+ try:
+ return Artifact.create(
+ source_uri=mlflow_model_path,
+ artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+ artifact_name=_artifact_name,
+ properties=properties,
+ sagemaker_session=sagemaker_session,
+ )
+ except ClientError as e:
+ if e.response[ERROR][CODE] == VALIDATION_EXCEPTION:
+ logger.info("Artifact already exists")
+ else:
+ logger.warning("Failed to create mlflow model path lineage artifact: %s", e)
+ raise e
+
+
+def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
+ mlflow_model_path: str,
+ sagemaker_session: Session,
+) -> Optional[Union[Artifact, ArtifactSummary]]:
+ """Retrieves an existing artifact for the given MLflow model path or
+
+ creates a new one if it doesn't exist.
+
+ Args:
+ mlflow_model_path (str): The path to the MLflow model.
+ sagemaker_session (Session): Session object which manages interactions
+ with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+ function creates one using the default AWS configuration chain.
+
+ Returns:
+ Optional[Union[Artifact, ArtifactSummary]]: The existing or newly created artifact,
+ or None if an error occurred.
+ """
+ _loaded_artifact = _load_artifact_by_source_uri(
+ mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session
+ )
+ if _loaded_artifact is not None:
+ return _loaded_artifact
+ return _create_mlflow_model_path_lineage_artifact(
+ mlflow_model_path,
+ sagemaker_session,
+ )
+
+
+def _add_association_between_artifacts(
+ mlflow_model_path_artifact_arn: str,
+ autogenerated_model_data_artifact_arn: str,
+ sagemaker_session: Session,
+) -> None:
+ """Add association between mlflow model path artifact and autogenerated model data artifact.
+
+ Arguments:
+ mlflow_model_path_artifact_arn (str): The mlflow model path artifact.
+ autogenerated_model_data_artifact_arn (str): The autogenerated model data artifact.
+ sagemaker_session (Session): Session object which manages interactions
+ with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+ function creates one using the default AWS configuration chain.
+ """
+ _association_type = CONTRIBUTED_TO
+ _source_arn = mlflow_model_path_artifact_arn
+ _destination_arn = autogenerated_model_data_artifact_arn
+ try:
+ logger.info(
+ "Adding association with source_arn: "
+ "%s, destination_arn: %s and association_type: %s.",
+ _source_arn,
+ _destination_arn,
+ _association_type,
+ )
+ Association.create(
+ source_arn=_source_arn,
+ destination_arn=_destination_arn,
+ association_type=_association_type,
+ sagemaker_session=sagemaker_session,
+ )
+ except ClientError as e:
+ if e.response[ERROR][CODE] == VALIDATION_EXCEPTION:
+ logger.info("Association already exists")
+ else:
+ raise e
+
+
+def _maintain_lineage_tracking_for_mlflow_model(
+ mlflow_model_path: str,
+ s3_upload_path: str,
+ sagemaker_session: Session,
+) -> None:
+ """Maintains lineage tracking for an MLflow model by creating or retrieving artifacts.
+
+ Args:
+ mlflow_model_path (str): The path to the MLflow model.
+ s3_upload_path (str): The S3 path where the transformed model data is uploaded.
+ sagemaker_session (Session): Session object which manages interactions
+ with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+ function creates one using the default AWS configuration chain.
+ """
+ artifact_for_transformed_model_data = _poll_lineage_artifact(
+ s3_uri=s3_upload_path,
+ artifact_type=LineageSourceEnum.MODEL_DATA.value,
+ sagemaker_session=sagemaker_session,
+ )
+ if artifact_for_transformed_model_data:
+ mlflow_model_artifact = (
+ _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
+ mlflow_model_path=mlflow_model_path,
+ sagemaker_session=sagemaker_session,
+ )
+ )
+ if mlflow_model_artifact:
+ _mlflow_model_artifact_arn = (
+ mlflow_model_artifact.artifact_arn
+            )  # pylint: disable=E1101,C0301
+ _artifact_for_transformed_model_data_arn = (
+ artifact_for_transformed_model_data.artifact_arn
+ ) # pylint: disable=C0301
+ _add_association_between_artifacts(
+ mlflow_model_path_artifact_arn=_mlflow_model_artifact_arn,
+ autogenerated_model_data_artifact_arn=_artifact_for_transformed_model_data_arn,
+ sagemaker_session=sagemaker_session,
+ )
+ else:
+ logger.warning(
+ "Unable to add association between autogenerated lineage "
+ "artifact for transformed model data and mlflow model path"
+ " lineage artifacts."
+ )
+ else:
+ logger.warning(
+ "Lineage artifact for transformed model data is not auto-created within "
+ "%s seconds, skipping creation of lineage artifact for mlflow model path",
+ LINEAGE_POLLER_MAX_TIMEOUT_SECS,
+ )
diff --git a/src/sagemaker/serve/utils/predictors.py b/src/sagemaker/serve/utils/predictors.py
index e0ff8f8ee1..25a995eb48 100644
--- a/src/sagemaker/serve/utils/predictors.py
+++ b/src/sagemaker/serve/utils/predictors.py
@@ -209,6 +209,90 @@ def delete_predictor(self):
self._mode_obj.destroy_server()
+class TeiLocalModePredictor(PredictorBase):
+ """Lightweight Tei predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes"""
+
+ def __init__(
+ self,
+ mode_obj: Type[LocalContainerMode],
+ serializer=JSONSerializer(),
+ deserializer=JSONDeserializer(),
+ ):
+ self._mode_obj = mode_obj
+ self.serializer = serializer
+ self.deserializer = deserializer
+
+ def predict(self, data):
+ """Placeholder docstring"""
+ return [
+ self.deserializer.deserialize(
+ io.BytesIO(
+ self._mode_obj._invoke_serving(
+ self.serializer.serialize(data),
+ self.content_type,
+ self.deserializer.ACCEPT[0],
+ )
+ ),
+ self.content_type,
+ )
+ ]
+
+ @property
+ def content_type(self):
+ """The MIME type of the data sent to the inference endpoint."""
+ return self.serializer.CONTENT_TYPE
+
+ @property
+ def accept(self):
+ """The content type(s) that are expected from the inference endpoint."""
+ return self.deserializer.ACCEPT
+
+ def delete_predictor(self):
+ """Shut down and remove the container that you created in LOCAL_CONTAINER mode"""
+ self._mode_obj.destroy_server()
+
+
+class TensorflowServingLocalPredictor(PredictorBase):
+ """Lightweight predictor for local deployment in LOCAL_CONTAINER modes"""
+
+ # TODO: change mode_obj to union of IN_PROCESS and LOCAL_CONTAINER objs
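+    # Defaults to byte pass-through: IdentitySerializer sends payloads unchanged
+    # and BytesDeserializer returns the raw response body.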
+ def __init__(
+ self,
+ mode_obj: Type[LocalContainerMode],
+ serializer=IdentitySerializer(),
+ deserializer=BytesDeserializer(),
+ ):
+ self._mode_obj = mode_obj
+ self.serializer = serializer
+ self.deserializer = deserializer
+
+ def predict(self, data):
+ """Placeholder docstring"""
+ return self.deserializer.deserialize(
+ io.BytesIO(
+ self._mode_obj._invoke_tensorflow_serving(
+ self.serializer.serialize(data),
+ self.content_type,
+ self.accept[0],
+ )
+ )
+ )
+
+ @property
+ def content_type(self):
+ """The MIME type of the data sent to the inference endpoint."""
+ return self.serializer.CONTENT_TYPE
+
+ @property
+ def accept(self):
+ """The content type(s) that are expected from the inference endpoint."""
+ return self.deserializer.ACCEPT
+
+ def delete_predictor(self):
+ """Shut down and remove the container that you created in LOCAL_CONTAINER mode"""
+ self._mode_obj.destroy_server()
+
+
def _get_local_mode_predictor(
model_server: ModelServer,
mode_obj: Type[LocalContainerMode],
@@ -223,6 +307,11 @@ def _get_local_mode_predictor(
if model_server == ModelServer.TRITON:
return TritonLocalPredictor(mode_obj=mode_obj)
+ if model_server == ModelServer.TENSORFLOW_SERVING:
+ return TensorflowServingLocalPredictor(
+ mode_obj=mode_obj, serializer=serializer, deserializer=deserializer
+ )
+
raise ValueError("%s model server is not supported yet!" % model_server)
diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py
index 64cbce03e8..342a88c945 100644
--- a/src/sagemaker/serve/utils/telemetry_logger.py
+++ b/src/sagemaker/serve/utils/telemetry_logger.py
@@ -19,7 +19,16 @@
from sagemaker import Session, exceptions
from sagemaker.serve.mode.function_pointers import Mode
+from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH
from sagemaker.serve.utils.exceptions import ModelBuilderException
+from sagemaker.serve.utils.lineage_constants import (
+ MLFLOW_LOCAL_PATH,
+ MLFLOW_S3_PATH,
+ MLFLOW_MODEL_PACKAGE_PATH,
+ MLFLOW_RUN_ID,
+ MLFLOW_REGISTRY_PATH,
+)
+from sagemaker.serve.utils.lineage_utils import _get_mlflow_model_path_type
from sagemaker.serve.utils.types import ModelServer, ImageUriOption
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
from sagemaker.user_agent import SDK_VERSION
@@ -49,6 +58,23 @@
str(ModelServer.DJL_SERVING): 4,
str(ModelServer.TRITON): 5,
str(ModelServer.TGI): 6,
+ str(ModelServer.TEI): 7,
+}
+
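+# Maps each MLflow model path type to a compact integer code for the telemetry
+# query string (emitted as x-mlflowModelPathType in the wrapper below).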
+MLFLOW_MODEL_PATH_CODE = {
+ MLFLOW_LOCAL_PATH: 1,
+ MLFLOW_S3_PATH: 2,
+ MLFLOW_MODEL_PACKAGE_PATH: 3,
+ MLFLOW_RUN_ID: 4,
+ MLFLOW_REGISTRY_PATH: 5,
}
@@ -78,6 +104,11 @@ def wrapper(self, *args, **kwargs):
if self.sagemaker_session and self.sagemaker_session.endpoint_arn:
extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}"
+ if getattr(self, "_is_mlflow_model", False):
+ mlflow_model_path = self.model_metadata[MLFLOW_MODEL_PATH]
+ mlflow_model_path_type = _get_mlflow_model_path_type(mlflow_model_path)
+ extra += f"&x-mlflowModelPathType={MLFLOW_MODEL_PATH_CODE[mlflow_model_path_type]}"
+
start_timer = perf_counter()
try:
response = func(self, *args, **kwargs)
diff --git a/src/sagemaker/serve/utils/types.py b/src/sagemaker/serve/utils/types.py
index 661093f249..3ac80aa7ea 100644
--- a/src/sagemaker/serve/utils/types.py
+++ b/src/sagemaker/serve/utils/types.py
@@ -18,6 +18,7 @@ def __str__(self):
DJL_SERVING = 4
TRITON = 5
TGI = 6
+ TEI = 7
class _DjlEngine(Enum):
diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py
index 9e593706c1..bf2a736871 100644
--- a/src/sagemaker/session.py
+++ b/src/sagemaker/session.py
@@ -121,7 +121,7 @@
from sagemaker.deprecations import deprecated_class
from sagemaker.enums import EndpointType
from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig
-from sagemaker.user_agent import prepend_user_agent
+from sagemaker.user_agent import get_user_agent_extra_suffix
from sagemaker.utils import (
name_from_image,
secondary_training_status_changed,
@@ -285,6 +285,7 @@ def _initialize(
Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
Sets the region_name.
"""
+
self.boto_session = boto_session or boto3.DEFAULT_SESSION or boto3.Session()
self._region_name = self.boto_session.region_name
@@ -293,19 +294,30 @@ def _initialize(
"Must setup local AWS configuration with a region supported by SageMaker."
)
- self.sagemaker_client = sagemaker_client or self.boto_session.client("sagemaker")
- prepend_user_agent(self.sagemaker_client)
+ # Make use of user_agent_extra field of the botocore_config object
+ # to append SageMaker Python SDK specific user_agent suffix
+ # to the current User-Agent header value from boto3
+ # This config will also make sure that user_agent never fails to log the User-Agent string
+ # even if boto User-Agent header format is updated in the future
+ # Ref: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ botocore_config = botocore.config.Config(user_agent_extra=get_user_agent_extra_suffix())
+
+ # Create sagemaker_client with the botocore_config object
+ # This config is customized to append SageMaker Python SDK specific user_agent suffix
+ self.sagemaker_client = sagemaker_client or self.boto_session.client(
+ "sagemaker", config=botocore_config
+ )
if sagemaker_runtime_client is not None:
self.sagemaker_runtime_client = sagemaker_runtime_client
else:
- config = botocore.config.Config(read_timeout=80)
+ config = botocore.config.Config(
+ read_timeout=80, user_agent_extra=get_user_agent_extra_suffix()
+ )
self.sagemaker_runtime_client = self.boto_session.client(
"runtime.sagemaker", config=config
)
- prepend_user_agent(self.sagemaker_runtime_client)
-
if sagemaker_featurestore_runtime_client:
self.sagemaker_featurestore_runtime_client = sagemaker_featurestore_runtime_client
else:
@@ -316,8 +328,9 @@ def _initialize(
if sagemaker_metrics_client:
self.sagemaker_metrics_client = sagemaker_metrics_client
else:
- self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics")
- prepend_user_agent(self.sagemaker_metrics_client)
+ self.sagemaker_metrics_client = self.boto_session.client(
+ "sagemaker-metrics", config=botocore_config
+ )
self.s3_client = self.boto_session.client("s3", region_name=self.boto_region_name)
self.s3_resource = self.boto_session.resource("s3", region_name=self.boto_region_name)
@@ -758,6 +771,7 @@ def train( # noqa: C901
environment: Optional[Dict[str, str]] = None,
retry_strategy=None,
remote_debug_config=None,
+ session_chaining_config=None,
):
"""Create an Amazon SageMaker training job.
@@ -877,6 +891,15 @@ def train( # noqa: C901
remote_debug_config = {
"EnableRemoteDebug": True,
}
+ session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``)
+ The dict can contain 'EnableSessionTagChaining'(bool).
+ For example,
+
+ .. code:: python
+
+ session_chaining_config = {
+ "EnableSessionTagChaining": True,
+ }
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
@@ -970,6 +993,7 @@ def train( # noqa: C901
profiler_rule_configs=profiler_rule_configs,
profiler_config=inferred_profiler_config,
remote_debug_config=remote_debug_config,
+ session_chaining_config=session_chaining_config,
environment=environment,
retry_strategy=retry_strategy,
)
@@ -1013,6 +1037,7 @@ def _get_train_request( # noqa: C901
profiler_rule_configs=None,
profiler_config=None,
remote_debug_config=None,
+ session_chaining_config=None,
environment=None,
retry_strategy=None,
):
@@ -1133,6 +1158,15 @@ def _get_train_request( # noqa: C901
remote_debug_config = {
"EnableRemoteDebug": True,
}
+ session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``)
+ The dict can contain 'EnableSessionTagChaining'(bool).
+ For example,
+
+ .. code:: python
+
+ session_chaining_config = {
+ "EnableSessionTagChaining": True,
+ }
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
@@ -1239,6 +1273,9 @@ def _get_train_request( # noqa: C901
if remote_debug_config is not None:
train_request["RemoteDebugConfig"] = remote_debug_config
+ if session_chaining_config is not None:
+ train_request["SessionChainingConfig"] = session_chaining_config
+
if retry_strategy is not None:
train_request["RetryStrategy"] = retry_strategy
diff --git a/src/sagemaker/user_agent.py b/src/sagemaker/user_agent.py
index 8af89696c2..c1b2bcac07 100644
--- a/src/sagemaker/user_agent.py
+++ b/src/sagemaker/user_agent.py
@@ -13,8 +13,6 @@
"""Placeholder docstring"""
from __future__ import absolute_import
-import platform
-import sys
import json
import os
@@ -28,12 +26,6 @@
STUDIO_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"
SDK_VERSION = importlib_metadata.version("sagemaker")
-OS_NAME = platform.system() or "UnresolvedOS"
-OS_VERSION = platform.release() or "UnresolvedOSVersion"
-OS_NAME_VERSION = "{}/{}".format(OS_NAME, OS_VERSION)
-PYTHON_VERSION = "Python/{}.{}.{}".format(
- sys.version_info.major, sys.version_info.minor, sys.version_info.micro
-)
def process_notebook_metadata_file():
@@ -63,45 +55,24 @@ def process_studio_metadata_file():
return None
-def determine_prefix(user_agent=""):
- """Determines the prefix for the user agent string.
+def get_user_agent_extra_suffix():
+ """Get the user agent extra suffix string specific to SageMaker Python SDK
- Args:
- user_agent (str): The user agent string to prepend the prefix to.
+    Adheres to the new boto-recommended User-Agent 2.0 header format
Returns:
- str: The user agent string with the prefix prepended.
+ str: The user agent extra suffix string to be appended
"""
- prefix = "{}/{}".format(SDK_PREFIX, SDK_VERSION)
-
- if PYTHON_VERSION not in user_agent:
- prefix = "{} {}".format(prefix, PYTHON_VERSION)
-
- if OS_NAME_VERSION not in user_agent:
- prefix = "{} {}".format(prefix, OS_NAME_VERSION)
+ suffix = "lib/{}#{}".format(SDK_PREFIX, SDK_VERSION)
-    # Get the notebook instance type and prepend it to the user agent string if exists
+    # Get the notebook instance type and append it to the user agent suffix if it exists
notebook_instance_type = process_notebook_metadata_file()
if notebook_instance_type:
- prefix = "{} {}/{}".format(prefix, NOTEBOOK_PREFIX, notebook_instance_type)
+ suffix = "{} md/{}#{}".format(suffix, NOTEBOOK_PREFIX, notebook_instance_type)
-    # Get the studio app type and prepend it to the user agent string if exists
+    # Get the studio app type and append it to the user agent suffix if it exists
studio_app_type = process_studio_metadata_file()
if studio_app_type:
- prefix = "{} {}/{}".format(prefix, STUDIO_PREFIX, studio_app_type)
-
- return prefix
-
-
-def prepend_user_agent(client):
- """Prepends the user agent string with the SageMaker Python SDK version.
-
- Args:
- client (botocore.client.BaseClient): The client to prepend the user agent string for.
- """
- prefix = determine_prefix(client._client_config.user_agent)
+ suffix = "{} md/{}#{}".format(suffix, STUDIO_PREFIX, studio_app_type)
- if client._client_config.user_agent is None:
- client._client_config.user_agent = prefix
- else:
- client._client_config.user_agent = "{} {}".format(prefix, client._client_config.user_agent)
+ return suffix
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index 0436c0afea..a70ba9eb98 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -25,6 +25,7 @@
import tarfile
import tempfile
import time
+from functools import lru_cache
from typing import Union, Any, List, Optional, Dict
import json
import abc
@@ -33,10 +34,12 @@
from os.path import abspath, realpath, dirname, normpath, join as joinpath
from importlib import import_module
+
+import boto3
import botocore
from botocore.utils import merge_dicts
from six.moves.urllib import parse
-import pandas as pd
+from six import viewitems
from sagemaker import deprecations
from sagemaker.config import validate_sagemaker_config
@@ -44,6 +47,7 @@
_log_sagemaker_config_single_substitution,
_log_sagemaker_config_merge,
)
+from sagemaker.enums import RoutingStrategy
from sagemaker.session_settings import SessionSettings
from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string
from sagemaker.workflow.entities import PipelineVariable
@@ -1602,44 +1606,80 @@ def can_model_package_source_uri_autopopulate(source_uri: str):
)
-def flatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
- """Flatten a nested dictionary.
+def flatten_dict(
+ d: Dict[str, Any],
+    max_flatten_depth: Optional[int] = None,
+) -> Dict[str, Any]:
+ """Flatten a dictionary object.
- Args:
- source_dict (dict): The dictionary to be flattened.
- sep (str): The separator to be used in the flattened dictionary.
- Returns:
- transformed_dict: The flattened dictionary.
+    Args:
+        d (Dict[str, Any]): The dict that will be flattened.
+        max_flatten_depth (Optional[int]): Maximum depth to flatten; nesting deeper
+            than this is left as-is.
+
+    Returns:
+        Dict[str, Any]: The flattened dict, keyed by tuples of the nested keys.
"""
- flat_dict_list = pd.json_normalize(source_dict, sep=sep).to_dict(orient="records")
- if flat_dict_list:
- return flat_dict_list[0]
- return {}
+ def tuple_reducer(k1, k2):
+ if k1 is None:
+ return (k2,)
+ return k1 + (k2,)
-def unflatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
- """Unflatten a flattened dictionary back into a nested dictionary.
+ # check max_flatten_depth
+ if max_flatten_depth is not None and max_flatten_depth < 1:
+ raise ValueError("max_flatten_depth should not be less than 1.")
- Args:
- source_dict (dict): The input flattened dictionary.
- sep (str): The separator used in the flattened keys.
+ reducer = tuple_reducer
- Returns:
- transformed_dict: The reconstructed nested dictionary.
+ flat_dict = {}
+
+ def _flatten(_d, depth, parent=None):
+ key_value_iterable = viewitems(_d)
+ has_item = False
+ for key, value in key_value_iterable:
+ has_item = True
+ flat_key = reducer(parent, key)
+ if isinstance(value, dict) and (max_flatten_depth is None or depth < max_flatten_depth):
+ has_child = _flatten(value, depth=depth + 1, parent=flat_key)
+ if has_child:
+ continue
+
+ if flat_key in flat_dict:
+ raise ValueError("duplicated key '{}'".format(flat_key))
+ flat_dict[flat_key] = value
+
+ return has_item
+
+ _flatten(d, depth=1)
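+    # e.g. flatten_dict({"a": {"b": 1}}) yields {("a", "b"): 1}; unflatten_dict
+    # below reverses these tuple keys.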
+ return flat_dict
+
+
+def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None:
+ """Set a value to a sequence of nested keys."""
+
+ key = keys[0]
+
+ if len(keys) == 1:
+ d[key] = value
+ return
+ if not d:
+ return
+
+ d = d.setdefault(key, {})
+ nested_set_dict(d, keys[1:], value)
+
+
+def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+ """Unflatten dict-like object.
+
+    Args:
+        d (Dict[str, Any]): The dict that will be unflattened.
+
+    Returns:
+        Dict[str, Any]: The reconstructed nested dict.
"""
- if not source_dict:
- return {}
- result = {}
- for key, value in source_dict.items():
- keys = key.split(sep)
- current = result
- for k in keys[:-1]:
- if k not in current:
- current[k] = {}
- current = current[k] if current[k] is not None else current
- current[keys[-1]] = value
- return result
+ unflattened_dict = {}
+ for flat_key, value in viewitems(d):
+        nested_set_dict(unflattened_dict, flat_key, value)
+
+ return unflattened_dict
def deep_override_dict(
@@ -1655,3 +1695,105 @@ def deep_override_dict(
)
flattened_dict1.update(flattened_dict2)
return unflatten_dict(flattened_dict1) if flattened_dict1 else {}
+
+
+def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+ """Resolve Routing Config
+
+ Args:
+ routing_config (Optional[Dict[str, Any]]): The routing config.
+
+ Returns:
+ Optional[Dict[str, Any]]: The resolved routing config.
+
+ Raises:
+ ValueError: If the RoutingStrategy is invalid.
+ """
+
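+    # Accepts either the RoutingStrategy enum or its (case-insensitive) string
+    # name; e.g. {"RoutingStrategy": "random"} normalizes to
+    # {"RoutingStrategy": "RANDOM"}.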
+ if routing_config:
+ routing_strategy = routing_config.get("RoutingStrategy", None)
+ if routing_strategy:
+ if isinstance(routing_strategy, RoutingStrategy):
+ return {"RoutingStrategy": routing_strategy.name}
+ if isinstance(routing_strategy, str) and (
+ routing_strategy.upper() == RoutingStrategy.RANDOM.name
+ or routing_strategy.upper() == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name
+ ):
+ return {"RoutingStrategy": routing_strategy.upper()}
+ raise ValueError(
+ "RoutingStrategy must be either RoutingStrategy.RANDOM "
+ "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS"
+ )
+ return None
+
+
+@lru_cache
+def get_instance_rate_per_hour(
+ instance_type: str,
+ region: str,
+) -> Optional[Dict[str, str]]:
+ """Gets instance rate per hour for the given instance type.
+
+ Args:
+ instance_type (str): The instance type.
+ region (str): The region.
+ Returns:
+ Optional[Dict[str, str]]: Instance rate per hour.
+ Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.125'}.
+
+ Raises:
+        Exception: If the IAM role is not authorized to perform
+            pricing:GetProducts, or if the rate could not be retrieved.
+ """
+ region_name = "us-east-1"
+ if region.startswith("eu") or region.startswith("af"):
+ region_name = "eu-central-1"
+ elif region.startswith("ap") or region.startswith("cn"):
+ region_name = "ap-south-1"
+
+ pricing_client: boto3.client = boto3.client("pricing", region_name=region_name)
+ res = pricing_client.get_products(
+ ServiceCode="AmazonSageMaker",
+ Filters=[
+ {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type},
+ {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"},
+ {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region},
+ ],
+ )
+
+ price_list = res.get("PriceList", [])
+ if len(price_list) > 0:
+ price_data = price_list[0]
+ if isinstance(price_data, str):
+ price_data = json.loads(price_data)
+
+ instance_rate_per_hour = extract_instance_rate_per_hour(price_data)
+ if instance_rate_per_hour is not None:
+ return instance_rate_per_hour
+ raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.")
+
+
+def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[str, str]]:
+ """Extract instance rate per hour for the given Price JSON data.
+
+ Args:
+ price_data (Dict[str, Any]): The Price JSON data.
+ Returns:
+        Optional[Dict[str, str]]: Instance rate per hour, or None if not found.
+ """
+
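+    # Price List JSON nests the hourly rate under
+    # terms.OnDemand.<sku>.priceDimensions.<rateCode>.pricePerUnit.<currency>.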
+ if price_data is not None:
+ price_dimensions = price_data.get("terms", {}).get("OnDemand", {}).values()
+ for dimension in price_dimensions:
+ for price in dimension.get("priceDimensions", {}).values():
+ for currency in price.get("pricePerUnit", {}).keys():
+ value = price.get("pricePerUnit", {}).get(currency)
+ if value is not None:
+ value = str(round(float(value), 3))
+ return {
+ "unit": f"{currency}/{price.get('unit', 'Hrs')}",
+ "value": value,
+ "name": "Instance Rate",
+ }
+ return None
diff --git a/tests/conftest.py b/tests/conftest.py
index 0309781e7b..7bab05dfb3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -253,7 +253,9 @@ def mxnet_eia_latest_py_version():
@pytest.fixture(scope="module", params=["py2", "py3"])
def pytorch_training_py_version(pytorch_training_version, request):
- if Version(pytorch_training_version) >= Version("2.0"):
+ if Version(pytorch_training_version) >= Version("2.3"):
+ return "py311"
+ elif Version(pytorch_training_version) >= Version("2.0"):
return "py310"
elif Version(pytorch_training_version) >= Version("1.13"):
return "py39"
diff --git a/tests/data/serve_resources/mlflow/pytorch/requirements.txt b/tests/data/serve_resources/mlflow/pytorch/requirements.txt
index 9848949b0f..895e2173bf 100644
--- a/tests/data/serve_resources/mlflow/pytorch/requirements.txt
+++ b/tests/data/serve_resources/mlflow/pytorch/requirements.txt
@@ -1,4 +1,4 @@
-mlflow==2.10.2
+mlflow==2.12.1
astunparse==1.6.3
cffi==1.16.0
cloudpickle==2.2.1
@@ -10,7 +10,7 @@ opt-einsum==3.3.0
packaging==21.3
pandas==2.2.1
pyyaml==6.0.1
-requests==2.31.0
+requests==2.32.2
torch==2.0.1
torchvision==0.15.2
-tqdm==4.66.2
+tqdm==4.66.3
diff --git a/tests/data/serve_resources/mlflow/tensorflow/MLmodel b/tests/data/serve_resources/mlflow/tensorflow/MLmodel
new file mode 100644
index 0000000000..f00412149d
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/MLmodel
@@ -0,0 +1,17 @@
+artifact_path: model
+flavors:
+ python_function:
+ env:
+ conda: conda.yaml
+ virtualenv: python_env.yaml
+ loader_module: mlflow.tensorflow
+ python_version: 3.10.13
+ tensorflow:
+ code: null
+ model_type: tf2-module
+ saved_model_dir: tf2model
+mlflow_version: 2.11.1
+model_size_bytes: 23823
+model_uuid: 40d2323944294fce898d8693455f60e8
+run_id: 592132312fb84935b201de2c027c54c6
+utc_time_created: '2024-04-01 19:47:15.396517'
diff --git a/tests/data/serve_resources/mlflow/tensorflow/conda.yaml b/tests/data/serve_resources/mlflow/tensorflow/conda.yaml
new file mode 100644
index 0000000000..90d8c300a0
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/conda.yaml
@@ -0,0 +1,11 @@
+channels:
+- conda-forge
+dependencies:
+- python=3.10.13
+- pip<=23.3.1
+- pip:
+ - mlflow==2.11.1
+ - cloudpickle==2.2.1
+ - numpy==1.26.4
+ - tensorflow==2.16.1
+name: mlflow-env
diff --git a/tests/data/serve_resources/mlflow/tensorflow/python_env.yaml b/tests/data/serve_resources/mlflow/tensorflow/python_env.yaml
new file mode 100644
index 0000000000..9e09178b6c
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/python_env.yaml
@@ -0,0 +1,7 @@
+python: 3.10.13
+build_dependencies:
+- pip==23.3.1
+- setuptools==68.2.2
+- wheel==0.41.2
+dependencies:
+- -r requirements.txt
diff --git a/tests/data/serve_resources/mlflow/tensorflow/registered_model_meta b/tests/data/serve_resources/mlflow/tensorflow/registered_model_meta
new file mode 100644
index 0000000000..5423c0e6c7
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/registered_model_meta
@@ -0,0 +1,2 @@
+model_name: model
+model_version: '2'
diff --git a/tests/data/serve_resources/mlflow/tensorflow/requirements.txt b/tests/data/serve_resources/mlflow/tensorflow/requirements.txt
new file mode 100644
index 0000000000..d4ff5b4782
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/requirements.txt
@@ -0,0 +1,4 @@
+mlflow==2.12.1
+cloudpickle==2.2.1
+numpy==1.26.4
+tensorflow==2.16.1
diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/fingerprint.pb b/tests/data/serve_resources/mlflow/tensorflow/tf2model/fingerprint.pb
new file mode 100644
index 0000000000..ba1e240ba5
--- /dev/null
+++ b/tests/data/serve_resources/mlflow/tensorflow/tf2model/fingerprint.pb
Binary files /dev/null and b/tests/data/serve_resources/mlflow/tensorflow/tf2model/fingerprint.pb differ
diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/saved_model.pb b/tests/data/serve_resources/mlflow/tensorflow/tf2model/saved_model.pb
new file mode 100644
index 0000000000..e48f2b59cc
Binary files /dev/null and b/tests/data/serve_resources/mlflow/tensorflow/tf2model/saved_model.pb differ
diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.data-00000-of-00001 b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.data-00000-of-00001
new file mode 100644
index 0000000000..575da96282
Binary files /dev/null and b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.data-00000-of-00001 differ
diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.index b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.index
new file mode 100644
index 0000000000..57646ac350
Binary files /dev/null and b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.index differ
diff --git a/tests/data/serve_resources/mlflow/xgboost/requirements.txt b/tests/data/serve_resources/mlflow/xgboost/requirements.txt
index 8150c9fedf..18d687aec6 100644
--- a/tests/data/serve_resources/mlflow/xgboost/requirements.txt
+++ b/tests/data/serve_resources/mlflow/xgboost/requirements.txt
@@ -1,4 +1,4 @@
-mlflow==2.11.1
+mlflow==2.12.1
lz4==4.3.2
numpy==1.24.4
pandas==2.0.3
diff --git a/tests/data/sip/training.py b/tests/data/sip/training.py
index 9b643b2c35..f13b8b3533 100644
--- a/tests/data/sip/training.py
+++ b/tests/data/sip/training.py
@@ -73,7 +73,8 @@ def main():
)
model_dir = os.environ.get("SM_MODEL_DIR")
- pkl.dump(bst, open(model_dir + "/model.bin", "wb"))
+ with open(model_dir + "/model.bin", "wb") as f:
+ pkl.dump(bst, f)
if __name__ == "__main__":
diff --git a/tests/integ/sagemaker/conftest.py b/tests/integ/sagemaker/conftest.py
index 2dc9f7df4d..043b0c703e 100644
--- a/tests/integ/sagemaker/conftest.py
+++ b/tests/integ/sagemaker/conftest.py
@@ -176,7 +176,7 @@ def conda_env_yml():
os.remove(conda_yml_file_name)
-def _build_container(sagemaker_session, py_version, docker_templete):
+def _build_container(sagemaker_session, py_version, docker_template):
"""Build a dummy test container locally and push a container to an ecr repo"""
region = sagemaker_session.boto_region_name
@@ -189,9 +189,9 @@ def _build_container(sagemaker_session, py_version, docker_templete):
print("building source archive...")
source_archive = _generate_sagemaker_sdk_tar(tmpdir)
with open(os.path.join(tmpdir, "Dockerfile"), "w") as file:
- file.writelines(
- docker_templete.format(py_version=py_version, source_archive=source_archive)
- )
+ content = docker_template.format(py_version=py_version, source_archive=source_archive)
+ print(f"Dockerfile contents: \n{content}\n")
+ file.writelines(content)
docker_client = docker.from_env()
@@ -209,6 +209,7 @@ def _build_container(sagemaker_session, py_version, docker_templete):
raise
if _is_repository_exists(ecr_client, REPO_NAME):
+ print("pushing to session configured account id!")
sts_client = sagemaker_session.boto_session.client(
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
)
@@ -218,6 +219,7 @@ def _build_container(sagemaker_session, py_version, docker_templete):
account_id, sagemaker_session.boto_region_name, REPO_NAME, image_tag
)
else:
+ print(f"pushing to account id: {REPO_ACCOUNT_ID}")
ecr_image = _ecr_image_uri(
REPO_ACCOUNT_ID,
sagemaker_session.boto_region_name,
@@ -232,7 +234,7 @@ def _build_container(sagemaker_session, py_version, docker_templete):
return ecr_image
-def _build_auto_capture_client_container(py_version, docker_templete):
+def _build_auto_capture_client_container(py_version, docker_template):
"""Build a test docker container that will act as a client for auto_capture tests"""
with _tmpdir() as tmpdir:
print("building docker image locally in ", tmpdir)
@@ -240,9 +242,9 @@ def _build_auto_capture_client_container(py_version, docker_templete):
source_archive = _generate_sdk_tar_with_public_version(tmpdir)
_move_auto_capture_test_file(tmpdir)
with open(os.path.join(tmpdir, "Dockerfile"), "w") as file:
- file.writelines(
- docker_templete.format(py_version=py_version, source_archive=source_archive)
- )
+ content = docker_template.format(py_version=py_version, source_archive=source_archive)
+ print(f"Dockerfile contents: \n{content}\n")
+ file.writelines(content)
docker_client = docker.from_env()
@@ -276,11 +278,14 @@ def _generate_sagemaker_sdk_tar(destination_folder):
"""
Run setup.py sdist to generate the PySDK tar file
"""
- subprocess.run(
- f"python3 setup.py egg_info --egg-base {destination_folder} sdist -d {destination_folder} -k",
- shell=True,
- check=True,
- )
+ command = f"python3 setup.py egg_info --egg-base {destination_folder} sdist -d {destination_folder} -k --verbose"
+ print(f"Running command: {command}")
+    # check=True raises CalledProcessError on failure, so the output below is
+    # only printed for successful runs.
+    result = subprocess.run(command, shell=True, check=True, capture_output=True)
+
+ print(f"Standard output: {result.stdout.decode()}")
+ print(f"Standard error: {result.stderr.decode()}")
destination_folder_contents = os.listdir(destination_folder)
source_archive = [file for file in destination_folder_contents if file.endswith("tar.gz")][0]
diff --git a/tests/integ/sagemaker/jumpstart/constants.py b/tests/integ/sagemaker/jumpstart/constants.py
index f5ffbf7a3a..b839866b1f 100644
--- a/tests/integ/sagemaker/jumpstart/constants.py
+++ b/tests/integ/sagemaker/jumpstart/constants.py
@@ -48,6 +48,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),
("meta-textgeneration-llama-2-7b", "3.*"): ("training-datasets/sec_amazon/"),
+ ("meta-textgeneration-llama-2-7b", "4.*"): ("training-datasets/sec_amazon/"),
("meta-textgenerationneuron-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
}
diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py
index a839a293c5..0da64ecf05 100644
--- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py
+++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py
@@ -140,7 +140,7 @@ def test_gated_model_training_v1(setup):
def test_gated_model_training_v2(setup):
model_id = "meta-textgeneration-llama-2-7b"
- model_version = "3.*" # model artifacts retrieved from jumpstart-private-cache-* buckets
+ model_version = "4.*" # model artifacts retrieved from jumpstart-private-cache-* buckets
estimator = JumpStartEstimator(
model_id=model_id,
@@ -150,6 +150,7 @@ def test_gated_model_training_v2(setup):
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
environment={"accept_eula": "true"},
max_run=259200, # avoid exceeding resource limits
+ tolerate_vulnerable_model=True, # tolerate old version of model
)
# uses ml.g5.12xlarge instance
diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py
index 4b9e816623..2fdf735d21 100644
--- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py
+++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py
@@ -142,8 +142,8 @@ def test_graph_visualize(sagemaker_session, extract_data_from_html):
lq_result.visualize(path="testGraph.html")
# check generated graph info
- fo = open("testGraph.html", "r")
- lines = fo.readlines()
+ with open("testGraph.html", "r") as fo:
+ lines = fo.readlines()
for line in lines:
if "nodes = " in line:
node = line
diff --git a/tests/integ/sagemaker/serve/constants.py b/tests/integ/sagemaker/serve/constants.py
index 794f7333a3..d5e7a56f83 100644
--- a/tests/integ/sagemaker/serve/constants.py
+++ b/tests/integ/sagemaker/serve/constants.py
@@ -32,6 +32,7 @@
DATA_DIR, "serve_resources", "mlflow", "pytorch"
)
XGBOOST_MLFLOW_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "mlflow", "xgboost")
+TENSORFLOW_MLFLOW_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "mlflow", "tensorflow")
TF_EFFICIENT_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "tensorflow")
HF_DIR = os.path.join(DATA_DIR, "serve_resources", "hf")
diff --git a/tests/integ/sagemaker/serve/test_serve_js_happy.py b/tests/integ/sagemaker/serve/test_serve_js_happy.py
index 7835c8ae3c..ad0527fcc0 100644
--- a/tests/integ/sagemaker/serve/test_serve_js_happy.py
+++ b/tests/integ/sagemaker/serve/test_serve_js_happy.py
@@ -34,6 +34,14 @@
JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16"
ROLE_NAME = "SageMakerRole"
+SAMPLE_MMS_PROMPT = [
+ "How cute your dog is!",
+ "Your dog is so cute.",
+ "The mitochondria is the powerhouse of the cell.",
+]
+SAMPLE_MMS_RESPONSE = {"embedding": []}
+JS_MMS_MODEL_ID = "huggingface-sentencesimilarity-bge-m3"
+
@pytest.fixture
def happy_model_builder(sagemaker_session):
@@ -46,6 +54,17 @@ def happy_model_builder(sagemaker_session):
)
+@pytest.fixture
+def happy_mms_model_builder(sagemaker_session):
+ iam_client = sagemaker_session.boto_session.client("iam")
+ return ModelBuilder(
+ model=JS_MMS_MODEL_ID,
+ schema_builder=SchemaBuilder(SAMPLE_MMS_PROMPT, SAMPLE_MMS_RESPONSE),
+ role_arn=iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"],
+ sagemaker_session=sagemaker_session,
+ )
+
+
@pytest.mark.skipif(
PYTHON_VERSION_IS_NOT_310,
reason="The goal of these test are to test the serving components of our feature",
@@ -75,3 +94,34 @@ def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type):
)
if caught_ex:
raise caught_ex
+
+
+@pytest.mark.skipif(
+ PYTHON_VERSION_IS_NOT_310,
+ reason="The goal of these test are to test the serving components of our feature",
+)
+@pytest.mark.slow_test
+def test_happy_mms_sagemaker_endpoint(happy_mms_model_builder, gpu_instance_type):
+ logger.info("Running in SAGEMAKER_ENDPOINT mode...")
+ caught_ex = None
+ model = happy_mms_model_builder.build()
+
+ with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+ try:
+ logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
+ predictor = model.deploy(instance_type=gpu_instance_type, endpoint_logging=False)
+ logger.info("Endpoint successfully deployed.")
+
+ updated_sample_input = happy_mms_model_builder.schema_builder.sample_input
+
+ predictor.predict(updated_sample_input)
+ except Exception as e:
+ caught_ex = e
+ finally:
+ cleanup_model_resources(
+ sagemaker_session=happy_mms_model_builder.sagemaker_session,
+ model_name=model.name,
+ endpoint_name=model.endpoint_name,
+ )
+ if caught_ex:
+ raise caught_ex
diff --git a/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py b/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py
index e7ebd9c5bf..e6beb76d6e 100644
--- a/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py
+++ b/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py
@@ -19,6 +19,8 @@
import io
import numpy as np
+from sagemaker.lineage.artifact import Artifact
+from sagemaker.lineage.association import Association
from sagemaker.s3 import S3Uploader
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
@@ -35,6 +37,10 @@
from tests.integ.utils import cleanup_model_resources
import logging
+from sagemaker.serve.utils.lineage_constants import (
+ MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+)
+
logger = logging.getLogger(__name__)
ROLE_NAME = "SageMakerRole"
@@ -205,6 +211,19 @@ def test_happy_pytorch_sagemaker_endpoint_with_torch_serve(
predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1)
logger.info("Endpoint successfully deployed.")
predictor.predict(test_image)
+ model_data_artifact = None
+ for artifact in Artifact.list(
+ source_uri=model_builder.s3_upload_path, sagemaker_session=sagemaker_session
+ ):
+ model_data_artifact = artifact
+ for association in Association.list(
+ destination_arn=model_data_artifact.artifact_arn,
+ sagemaker_session=sagemaker_session,
+ ):
+ assert (
+ association.source_type == MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE
+ )
+ break
except Exception as e:
caught_ex = e
finally:
@@ -214,9 +233,4 @@ def test_happy_pytorch_sagemaker_endpoint_with_torch_serve(
endpoint_name=model.endpoint_name,
)
if caught_ex:
- logger.exception(caught_ex)
- ignore_if_worker_dies = "Worker died." in str(caught_ex)
- # https://github.com/pytorch/serve/issues/3032
- assert (
- ignore_if_worker_dies
- ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test"
+ raise caught_ex
diff --git a/tests/integ/sagemaker/serve/test_serve_mlflow_tensorflow_flavor_happy.py b/tests/integ/sagemaker/serve/test_serve_mlflow_tensorflow_flavor_happy.py
new file mode 100644
index 0000000000..c25cbd7e18
--- /dev/null
+++ b/tests/integ/sagemaker/serve/test_serve_mlflow_tensorflow_flavor_happy.py
@@ -0,0 +1,175 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+import io
+import numpy as np
+
+from sagemaker.lineage.artifact import Artifact
+from sagemaker.lineage.association import Association
+from sagemaker.s3 import S3Uploader
+from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
+from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
+import tensorflow as tf
+from sklearn.datasets import fetch_california_housing
+
+
+from tests.integ.sagemaker.serve.constants import (
+ TENSORFLOW_MLFLOW_RESOURCE_DIR,
+ SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
+ PYTHON_VERSION_IS_NOT_310,
+)
+from tests.integ.timeout import timeout
+from tests.integ.utils import cleanup_model_resources
+import logging
+
+from sagemaker.serve.utils.lineage_constants import (
+ MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+)
+
+logger = logging.getLogger(__name__)
+
+ROLE_NAME = "SageMakerRole"
+
+
+@pytest.fixture
+def test_data():
+ dataset = fetch_california_housing(as_frame=True)["frame"]
+ dataset = dataset.dropna()
+ dataset_tf = tf.convert_to_tensor(dataset, dtype=tf.float32)
+ dataset_tf = dataset_tf[:50]
+ x_test, y_test = dataset_tf[:, :-1], dataset_tf[:, -1]
+ return x_test, y_test
+
+
+@pytest.fixture
+def custom_request_translator():
+ class MyRequestTranslator(CustomPayloadTranslator):
+ def serialize_payload_to_bytes(self, payload: object) -> bytes:
+ return self._convert_numpy_to_bytes(payload)
+
+ def deserialize_payload_from_stream(self, stream) -> object:
+ np_array = np.load(io.BytesIO(stream.read()))
+ return np_array
+
+ def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes:
+ buffer = io.BytesIO()
+ np.save(buffer, np_array)
+ return buffer.getvalue()
+
+ return MyRequestTranslator()
+
+
+@pytest.fixture
+def custom_response_translator():
+ class MyResponseTranslator(CustomPayloadTranslator):
+ def serialize_payload_to_bytes(self, payload: object) -> bytes:
+ import numpy as np
+
+ return self._convert_numpy_to_bytes(np.array(payload))
+
+ def deserialize_payload_from_stream(self, stream) -> object:
+ import tensorflow as tf
+
+ return tf.convert_to_tensor(np.load(io.BytesIO(stream.read())))
+
+ def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes:
+ buffer = io.BytesIO()
+ np.save(buffer, np_array)
+ return buffer.getvalue()
+
+ return MyResponseTranslator()
+
+
+@pytest.fixture
+def tensorflow_schema_builder(custom_request_translator, custom_response_translator, test_data):
+ input_data, output_data = test_data
+ return SchemaBuilder(
+ sample_input=input_data,
+ sample_output=output_data,
+ input_translator=custom_request_translator,
+ output_translator=custom_response_translator,
+ )
+
+
+@pytest.mark.skipif(
+ PYTHON_VERSION_IS_NOT_310,
+ reason="The goal of these test are to test the serving components of our feature",
+)
+def test_happy_tensorflow_sagemaker_endpoint_with_tensorflow_serving(
+ sagemaker_session,
+ tensorflow_schema_builder,
+ cpu_instance_type,
+ test_data,
+):
+ logger.info("Running in SAGEMAKER_ENDPOINT mode...")
+ caught_ex = None
+
+ iam_client = sagemaker_session.boto_session.client("iam")
+ role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
+
+ model_artifacts_uri = "s3://{}/{}/{}/{}".format(
+ sagemaker_session.default_bucket(),
+ "model_builder_integ_test",
+ "mlflow",
+ "tensorflow",
+ )
+
+ model_path = S3Uploader.upload(
+ local_path=TENSORFLOW_MLFLOW_RESOURCE_DIR,
+ desired_s3_uri=model_artifacts_uri,
+ sagemaker_session=sagemaker_session,
+ )
+
+ model_builder = ModelBuilder(
+ mode=Mode.SAGEMAKER_ENDPOINT,
+ schema_builder=tensorflow_schema_builder,
+ role_arn=role_arn,
+ sagemaker_session=sagemaker_session,
+ model_metadata={"MLFLOW_MODEL_PATH": model_path},
+ )
+
+ model = model_builder.build(sagemaker_session=sagemaker_session)
+
+ with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+ try:
+ test_x, _ = test_data
+ logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
+ predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1)
+ logger.info("Endpoint successfully deployed.")
+ predictor.predict(test_x)
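+ # Verify lineage: the MLflow model path artifact should be linked (as source) to the uploaded model data artifact.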
+ model_data_artifact = None
+ for artifact in Artifact.list(
+ source_uri=model_builder.s3_upload_path, sagemaker_session=sagemaker_session
+ ):
+ model_data_artifact = artifact
+ for association in Association.list(
+ destination_arn=model_data_artifact.artifact_arn,
+ sagemaker_session=sagemaker_session,
+ ):
+ assert (
+ association.source_type == MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE
+ )
+ break
+
+ except Exception as e:
+ caught_ex = e
+ finally:
+ cleanup_model_resources(
+ sagemaker_session=model_builder.sagemaker_session,
+ model_name=model.name,
+ endpoint_name=model.endpoint_name,
+ )
+ if caught_ex:
+ raise caught_ex
diff --git a/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py b/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py
index 5a73942afe..7b47440a97 100644
--- a/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py
+++ b/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py
@@ -16,6 +16,8 @@
import io
import numpy as np
+from sagemaker.lineage.artifact import Artifact
+from sagemaker.lineage.association import Association
from sagemaker.s3 import S3Uploader
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
@@ -32,6 +34,10 @@
from tests.integ.utils import cleanup_model_resources
import logging
+from sagemaker.serve.utils.lineage_constants import (
+ MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+)
+
logger = logging.getLogger(__name__)
ROLE_NAME = "SageMakerRole"
@@ -187,6 +193,19 @@ def test_happy_xgboost_sagemaker_endpoint_with_torch_serve(
predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1)
logger.info("Endpoint successfully deployed.")
predictor.predict(test_x)
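+ # Same lineage verification as in the TensorFlow flavor test above.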
+ model_data_artifact = None
+ for artifact in Artifact.list(
+ source_uri=model_builder.s3_upload_path, sagemaker_session=sagemaker_session
+ ):
+ model_data_artifact = artifact
+ for association in Association.list(
+ destination_arn=model_data_artifact.artifact_arn,
+ sagemaker_session=sagemaker_session,
+ ):
+ assert (
+ association.source_type == MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE
+ )
+ break
except Exception as e:
caught_ex = e
finally:
@@ -196,9 +215,4 @@ def test_happy_xgboost_sagemaker_endpoint_with_torch_serve(
endpoint_name=model.endpoint_name,
)
if caught_ex:
- logger.exception(caught_ex)
- ignore_if_worker_dies = "Worker died." in str(caught_ex)
- # https://github.com/pytorch/serve/issues/3032
- assert (
- ignore_if_worker_dies
- ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test"
+ raise caught_ex
diff --git a/tests/integ/sagemaker/serve/test_serve_tei.py b/tests/integ/sagemaker/serve/test_serve_tei.py
new file mode 100644
index 0000000000..5cf1a3635c
--- /dev/null
+++ b/tests/integ/sagemaker/serve/test_serve_tei.py
@@ -0,0 +1,87 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+from sagemaker.serve.builder.schema_builder import SchemaBuilder
+from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
+
+from tests.integ.sagemaker.serve.constants import (
+ HF_DIR,
+ PYTHON_VERSION_IS_NOT_310,
+ SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
+)
+
+from tests.integ.timeout import timeout
+from tests.integ.utils import cleanup_model_resources
+import logging
+
+logger = logging.getLogger(__name__)
+
+sample_input = {"inputs": "What is Deep Learning?"}
+
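+# Placeholder sample output paired with sample_input when constructing the SchemaBuilder below.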
+loaded_response = []
+
+
+@pytest.fixture
+def model_input():
+ return {"inputs": "What is Deep Learning?"}
+
+
+@pytest.fixture
+def model_builder_model_schema_builder():
+ return ModelBuilder(
+ model_path=HF_DIR,
+ model="BAAI/bge-m3",
+ schema_builder=SchemaBuilder(sample_input, loaded_response),
+ )
+
+
+@pytest.fixture
+def model_builder(request):
+ return request.getfixturevalue(request.param)
+
+
+@pytest.mark.skipif(
+ PYTHON_VERSION_IS_NOT_310,
+ reason="Testing feature needs latest metadata",
+)
+@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
+def test_tei_sagemaker_endpoint(sagemaker_session, model_builder, model_input):
+ logger.info("Running in SAGEMAKER_ENDPOINT mode...")
+ caught_ex = None
+
+ iam_client = sagemaker_session.boto_session.client("iam")
+ role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
+
+ model = model_builder.build(
+ mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
+ )
+
+ with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+ try:
+ logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
+ predictor = model.deploy(instance_type="ml.g5.2xlarge", initial_instance_count=1)
+ predictor.predict(model_input)
+ assert predictor is not None
+ except Exception as e:
+ caught_ex = e
+ finally:
+ cleanup_model_resources(
+ sagemaker_session=model_builder.sagemaker_session,
+ model_name=model.name,
+ endpoint_name=model.endpoint_name,
+ )
+ if caught_ex:
+ logger.exception(caught_ex)
+ assert False, f"{caught_ex} was thrown when running tei sagemaker endpoint test"
diff --git a/tests/integ/sagemaker/serve/test_serve_transformers.py b/tests/integ/sagemaker/serve/test_serve_transformers.py
index 64029f7290..33a1ae6708 100644
--- a/tests/integ/sagemaker/serve/test_serve_transformers.py
+++ b/tests/integ/sagemaker/serve/test_serve_transformers.py
@@ -127,4 +127,4 @@ def test_pytorch_transformers_sagemaker_endpoint(
logger.exception(caught_ex)
assert (
False
- ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test"
+ ), f"{caught_ex} thrown when running pytorch transformers sagemaker endpoint test"
diff --git a/tests/integ/sagemaker/workflow/test_workflow.py b/tests/integ/sagemaker/workflow/test_workflow.py
index 81871d13f4..2643a3b88e 100644
--- a/tests/integ/sagemaker/workflow/test_workflow.py
+++ b/tests/integ/sagemaker/workflow/test_workflow.py
@@ -527,7 +527,8 @@ def test_one_step_ingestion_pipeline(
temp_flow_path = "./ingestion.flow"
with cleanup_feature_group(feature_group):
- json.dump(ingestion_only_flow, open(temp_flow_path, "w"))
+ with open(temp_flow_path, "w") as f:
+ json.dump(ingestion_only_flow, f)
data_wrangler_processor = DataWranglerProcessor(
role=role,
diff --git a/tests/integ/test_sagemaker_config.py b/tests/integ/test_sagemaker_config.py
index 45efb1f6ab..8cc8d50053 100644
--- a/tests/integ/test_sagemaker_config.py
+++ b/tests/integ/test_sagemaker_config.py
@@ -66,7 +66,8 @@ def expected_merged_config():
expected_merged_config_file_path = os.path.join(
CONFIG_DATA_DIR, "expected_output_config_after_merge.yaml"
)
- return yaml.safe_load(open(expected_merged_config_file_path, "r").read())
+ with open(expected_merged_config_file_path, "r") as f:
+ return yaml.safe_load(f.read())
@pytest.fixture(scope="module")
@@ -171,7 +172,8 @@ def test_config_download_from_s3_and_merge(
CONFIG_DATA_DIR, "sample_additional_config_for_merge.yaml"
)
- config_file_1_as_yaml = open(config_file_1_local_path, "r").read()
+ with open(config_file_1_local_path, "r") as f:
+ config_file_1_as_yaml = f.read()
s3_uri_config_1 = os.path.join(s3_uri_prefix, "config_1.yaml")
# Upload S3 files in case they dont already exist
diff --git a/tests/unit/sagemaker/deserializers/test_deserializers.py b/tests/unit/sagemaker/deserializers/test_deserializers.py
index b8ede11ba5..cb1923a094 100644
--- a/tests/unit/sagemaker/deserializers/test_deserializers.py
+++ b/tests/unit/sagemaker/deserializers/test_deserializers.py
@@ -142,7 +142,8 @@ def test_numpy_deserializer_from_npy(numpy_deserializer):
assert np.array_equal(array, result)
-def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
+def test_numpy_deserializer_from_npy_object_array():
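+ # Loading object arrays requires pickle, so opt in explicitly now that allow_pickle defaults to False.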
+ numpy_deserializer = NumpyDeserializer(allow_pickle=True)
array = np.array([{"a": "", "b": ""}, {"c": "", "d": ""}])
stream = io.BytesIO()
np.save(stream, array)
diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py
index 582e5cf82d..fa10fd24fe 100644
--- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py
+++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py
@@ -18,6 +18,14 @@
from tests.unit.sagemaker.image_uris import expected_uris, conftest
LMI_VERSIONS = ["0.24.0"]
+TEI_VERSIONS_MAPPING = {
+ "gpu": {
+ "1.2.3": "2.0.1-tei1.2.3-gpu-py310-cu122-ubuntu22.04",
+ },
+ "cpu": {
+ "1.2.3": "2.0.1-tei1.2.3-cpu-py310-ubuntu22.04",
+ },
+}
HF_VERSIONS_MAPPING = {
"gpu": {
"0.6.0": "2.0.0-tgi0.6.0-gpu-py39-cu118-ubuntu20.04",
@@ -32,6 +40,8 @@
"1.4.2": "2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04",
"1.4.5": "2.1.1-tgi1.4.5-gpu-py310-cu121-ubuntu22.04",
"2.0.0": "2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04",
+ "2.0.1": "2.1.1-tgi2.0.1-gpu-py310-cu121-ubuntu22.04",
+ "2.0.2": "2.3.0-tgi2.0.2-gpu-py310-cu121-ubuntu22.04",
},
"inf2": {
"0.0.16": "1.13.1-optimum0.0.16-neuronx-py310-ubuntu22.04",
@@ -40,6 +50,7 @@
"0.0.19": "1.13.1-optimum0.0.19-neuronx-py310-ubuntu22.04",
"0.0.20": "1.13.1-optimum0.0.20-neuronx-py310-ubuntu22.04",
"0.0.21": "1.13.1-optimum0.0.21-neuronx-py310-ubuntu22.04",
+ "0.0.22": "2.1.2-optimum0.0.22-neuronx-py310-ubuntu22.04",
},
}
@@ -65,6 +76,28 @@ def test_huggingface_uris(load_config):
assert expected == uri
+@pytest.mark.parametrize(
+ "load_config", ["huggingface-tei.json", "huggingface-tei-cpu.json"], indirect=True
+)
+def test_huggingface_tei_uris(load_config):
+ VERSIONS = load_config["inference"]["versions"]
+ device = load_config["inference"]["processors"][0]
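+ # GPU configs resolve to the "huggingface-tei" backend and "tei" repo; CPU configs to the "-cpu" variants.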
+ backend = "huggingface-tei" if device == "gpu" else "huggingface-tei-cpu"
+ repo = "tei" if device == "gpu" else "tei-cpu"
+ for version in VERSIONS:
+ ACCOUNTS = load_config["inference"]["versions"][version]["registries"]
+ for region in ACCOUNTS.keys():
+ uri = get_huggingface_llm_image_uri(backend, region=region, version=version)
+ expected = expected_uris.huggingface_llm_framework_uri(
+ repo,
+ ACCOUNTS[region],
+ version,
+ TEI_VERSIONS_MAPPING[device][version],
+ region=region,
+ )
+ assert expected == uri
+
+
@pytest.mark.parametrize("load_config", ["huggingface-llm.json"], indirect=True)
def test_lmi_uris(load_config):
VERSIONS = load_config["inference"]["versions"]
diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py
index 2b6856b1f3..fb7ca38bad 100644
--- a/tests/unit/sagemaker/jumpstart/constants.py
+++ b/tests/unit/sagemaker/jumpstart/constants.py
@@ -6270,6 +6270,10 @@
"framework_version": "1.5.0",
"py_version": "py3",
},
+ "default_inference_instance_type": "ml.p2.xlarge",
+ "supported_inference_instance_type": ["ml.p2.xlarge", "ml.p3.xlarge"],
+ "default_training_instance_type": "ml.p2.xlarge",
+ "supported_training_instance_type": ["ml.p2.xlarge", "ml.p3.xlarge"],
"hosting_artifact_key": "pytorch-infer/infer-pytorch-eqa-bert-base-cased.tar.gz",
"hosting_script_key": "source-directory-tarballs/pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz",
"inference_vulnerable": False,
@@ -7353,7 +7357,7 @@
"training_model_package_artifact_uris": None,
"deprecate_warn_message": None,
"deprecated_message": None,
- "hosting_model_package_arns": None,
+ "hosting_model_package_arns": {},
"hosting_eula_key": None,
"model_subscription_link": None,
"hyperparameters": [
@@ -7658,35 +7662,57 @@
"inference_configs": {
"neuron-inference": {
"benchmark_metrics": {
- "ml.inf2.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}
+ "ml.inf2.2xlarge": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ]
},
- "component_names": ["neuron-base"],
+ "component_names": ["neuron-inference"],
},
"neuron-inference-budget": {
"benchmark_metrics": {
- "ml.inf2.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}
+ "ml.inf2.2xlarge": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ]
},
"component_names": ["neuron-base"],
},
"gpu-inference-budget": {
"benchmark_metrics": {
- "ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}
+ "ml.p3.2xlarge": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ]
},
"component_names": ["gpu-inference-budget"],
},
"gpu-inference": {
"benchmark_metrics": {
- "ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}
+ "ml.p3.2xlarge": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ]
},
"component_names": ["gpu-inference"],
},
+ "gpu-inference-model-package": {
+ "benchmark_metrics": {
+ "ml.p3.2xlarge": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ]
+ },
+ "component_names": ["gpu-inference-model-package"],
+ },
},
"inference_config_components": {
"neuron-base": {
"supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"]
},
"neuron-inference": {
+ "default_inference_instance_type": "ml.inf2.xlarge",
"supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"],
+ "hosting_ecr_specs": {
+ "framework": "huggingface-llm-neuronx",
+ "framework_version": "0.0.17",
+ "py_version": "py310",
+ },
"hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/",
"hosting_instance_type_variants": {
"regional_aliases": {
@@ -7715,6 +7741,14 @@
},
},
},
+ "gpu-inference-model-package": {
+ "default_inference_instance_type": "ml.p2.xlarge",
+ "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"],
+ "hosting_model_package_arns": {
+ "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll"
+ "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
+ },
+ },
"gpu-inference-budget": {
"supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"],
"hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/",
@@ -7738,35 +7772,70 @@
"training_configs": {
"neuron-training": {
"benchmark_metrics": {
- "ml.tr1n1.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"},
- "ml.tr1n1.4xlarge": {"name": "Latency", "value": "50", "unit": "Tokens/S"},
+ "ml.tr1n1.2xlarge": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ],
+ "ml.tr1n1.4xlarge": [
+ {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1}
+ ],
},
"component_names": ["neuron-training"],
+ "default_inference_config": "neuron-inference",
+ "default_incremental_training_config": "neuron-training",
+ "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"],
+ "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"],
},
"neuron-training-budget": {
"benchmark_metrics": {
- "ml.tr1n1.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"},
- "ml.tr1n1.4xlarge": {"name": "Latency", "value": "50", "unit": "Tokens/S"},
+ "ml.tr1n1.2xlarge": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ],
+ "ml.tr1n1.4xlarge": [
+ {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1}
+ ],
},
"component_names": ["neuron-training-budget"],
+ "default_inference_config": "neuron-inference-budget",
+ "default_incremental_training_config": "neuron-training-budget",
+ "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"],
+ "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"],
},
"gpu-training": {
"benchmark_metrics": {
- "ml.p3.2xlarge": {"name": "Latency", "value": "200", "unit": "Tokens/S"},
+ "ml.p3.2xlarge": [
+ {"name": "Latency", "value": "200", "unit": "Tokens/S", "concurrency": "1"}
+ ],
},
"component_names": ["gpu-training"],
+ "default_inference_config": "gpu-inference",
+ "default_incremental_training_config": "gpu-training",
+ "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"],
+ "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"],
},
"gpu-training-budget": {
"benchmark_metrics": {
- "ml.p3.2xlarge": {"name": "Latency", "value": "100", "unit": "Tokens/S"}
+ "ml.p3.2xlarge": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": "1"}
+ ]
},
"component_names": ["gpu-training-budget"],
+ "default_inference_config": "gpu-inference-budget",
+ "default_incremental_training_config": "gpu-training-budget",
+ "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"],
+ "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"],
},
},
"training_config_components": {
"neuron-training": {
+ "default_training_instance_type": "ml.trn1.2xlarge",
"supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"],
"training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/",
+ "training_ecr_specs": {
+ "framework": "huggingface",
+ "framework_version": "2.0.0",
+ "py_version": "py310",
+ "huggingface_transformers_version": "4.28.1",
+ },
"training_instance_type_variants": {
"regional_aliases": {
"us-west-2": {
@@ -7778,6 +7847,7 @@
},
},
"gpu-training": {
+ "default_training_instance_type": "ml.p2.xlarge",
"supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"],
"training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/",
"training_instance_type_variants": {
@@ -7794,6 +7864,7 @@
},
},
"neuron-training-budget": {
+ "default_training_instance_type": "ml.trn1.2xlarge",
"supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"],
"training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/",
"training_instance_type_variants": {
@@ -7807,6 +7878,7 @@
},
},
"gpu-training-budget": {
+ "default_training_instance_type": "ml.p2.xlarge",
"supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"],
"training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/",
"training_instance_type_variants": {
@@ -7897,3 +7969,170 @@
},
}
}
+
+
+DEPLOYMENT_CONFIGS = [
+ {
+ "DeploymentConfigName": "neuron-inference",
+ "DeploymentArgs": {
+ "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4"
+ ".0-gpu-py310-cu121-ubuntu20.04",
+ "ModelData": {
+ "S3DataSource": {
+ "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface"
+ "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/",
+ "S3DataType": "S3Prefix",
+ "CompressionType": "None",
+ }
+ },
+ "Environment": {
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "ENDPOINT_SERVER_TIMEOUT": "3600",
+ "MODEL_CACHE_ROOT": "/opt/ml/model",
+ "SAGEMAKER_ENV": "1",
+ "HF_MODEL_ID": "/opt/ml/model",
+ "SM_NUM_GPUS": "1",
+ "MAX_INPUT_LENGTH": "2047",
+ "MAX_TOTAL_TOKENS": "2048",
+ "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+ },
+ "InstanceType": "ml.p2.xlarge",
+ "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None},
+ "ModelDataDownloadTimeout": None,
+ "ContainerStartupHealthCheckTimeout": None,
+ },
+ "AccelerationConfigs": None,
+ "BenchmarkMetrics": [
+ {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1}
+ ],
+ },
+ {
+ "DeploymentConfigName": "neuron-inference-budget",
+ "DeploymentArgs": {
+ "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4"
+ ".0-gpu-py310-cu121-ubuntu20.04",
+ "ModelData": {
+ "S3DataSource": {
+ "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface"
+ "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/",
+ "S3DataType": "S3Prefix",
+ "CompressionType": "None",
+ }
+ },
+ "Environment": {
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "ENDPOINT_SERVER_TIMEOUT": "3600",
+ "MODEL_CACHE_ROOT": "/opt/ml/model",
+ "SAGEMAKER_ENV": "1",
+ "HF_MODEL_ID": "/opt/ml/model",
+ "SM_NUM_GPUS": "1",
+ "MAX_INPUT_LENGTH": "2047",
+ "MAX_TOTAL_TOKENS": "2048",
+ "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+ },
+ "InstanceType": "ml.p2.xlarge",
+ "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None},
+ "ModelDataDownloadTimeout": None,
+ "ContainerStartupHealthCheckTimeout": None,
+ },
+ "AccelerationConfigs": None,
+ "BenchmarkMetrics": [
+ {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1}
+ ],
+ },
+ {
+ "DeploymentConfigName": "gpu-inference-budget",
+ "DeploymentArgs": {
+ "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4"
+ ".0-gpu-py310-cu121-ubuntu20.04",
+ "ModelData": {
+ "S3DataSource": {
+ "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface"
+ "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/",
+ "S3DataType": "S3Prefix",
+ "CompressionType": "None",
+ }
+ },
+ "Environment": {
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "ENDPOINT_SERVER_TIMEOUT": "3600",
+ "MODEL_CACHE_ROOT": "/opt/ml/model",
+ "SAGEMAKER_ENV": "1",
+ "HF_MODEL_ID": "/opt/ml/model",
+ "SM_NUM_GPUS": "1",
+ "MAX_INPUT_LENGTH": "2047",
+ "MAX_TOTAL_TOKENS": "2048",
+ "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+ },
+ "InstanceType": "ml.p2.xlarge",
+ "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None},
+ "ModelDataDownloadTimeout": None,
+ "ContainerStartupHealthCheckTimeout": None,
+ },
+ "AccelerationConfigs": None,
+ "BenchmarkMetrics": [
+ {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1}
+ ],
+ },
+ {
+ "DeploymentConfigName": "gpu-inference",
+ "DeploymentArgs": {
+ "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4"
+ ".0-gpu-py310-cu121-ubuntu20.04",
+ "ModelData": {
+ "S3DataSource": {
+ "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface"
+ "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/",
+ "S3DataType": "S3Prefix",
+ "CompressionType": "None",
+ }
+ },
+ "Environment": {
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "ENDPOINT_SERVER_TIMEOUT": "3600",
+ "MODEL_CACHE_ROOT": "/opt/ml/model",
+ "SAGEMAKER_ENV": "1",
+ "HF_MODEL_ID": "/opt/ml/model",
+ "SM_NUM_GPUS": "1",
+ "MAX_INPUT_LENGTH": "2047",
+ "MAX_TOTAL_TOKENS": "2048",
+ "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+ },
+ "InstanceType": "ml.p2.xlarge",
+ "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None},
+ "ModelDataDownloadTimeout": None,
+ "ContainerStartupHealthCheckTimeout": None,
+ },
+ "AccelerationConfigs": None,
+ "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}],
+ },
+]
+
+
+INIT_KWARGS = {
+ "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu"
+ "-py310-cu121-ubuntu20.04",
+ "model_data": {
+ "S3DataSource": {
+ "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface-textgeneration"
+ "-bloom-1b1/artifacts/inference-prepack/v4.0.0/",
+ "S3DataType": "S3Prefix",
+ "CompressionType": "None",
+ }
+ },
+ "instance_type": "ml.p2.xlarge",
+ "env": {
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "ENDPOINT_SERVER_TIMEOUT": "3600",
+ "MODEL_CACHE_ROOT": "/opt/ml/model",
+ "SAGEMAKER_ENV": "1",
+ "HF_MODEL_ID": "/opt/ml/model",
+ "SM_NUM_GPUS": "1",
+ "MAX_INPUT_LENGTH": "2047",
+ "MAX_TOTAL_TOKENS": "2048",
+ "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+ },
+ "role": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628",
+ "name": "hf-textgeneration-bloom-1b1-2024-04-22-20-23-48-799",
+ "enable_network_isolation": True,
+}
diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
index ce5f15b287..2af470a13e 100644
--- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
+++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
@@ -46,6 +46,8 @@
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from tests.unit.sagemaker.jumpstart.utils import (
+ get_prototype_manifest,
+ get_prototype_spec_with_configs,
get_special_model_spec,
overwrite_dictionary,
)
@@ -680,7 +682,6 @@ def test_estimator_use_kwargs(self):
"input_mode": "File",
"output_path": "Optional[Union[str, PipelineVariable]] = None",
"output_kms_key": "Optional[Union[str, PipelineVariable]] = None",
- "base_job_name": "Optional[str] = None",
"sagemaker_session": DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
"hyperparameters": {"hyp1": "val1"},
"tags": [],
@@ -1009,12 +1010,16 @@ def test_jumpstart_estimator_attach_eula_model(
additional_kwargs={
"model_id": "gemma-model",
"model_version": "*",
+ "tolerate_vulnerable_model": True,
+ "tolerate_deprecated_model": True,
"environment": {"accept_eula": "true"},
+ "tolerate_vulnerable_model": True,
+ "tolerate_deprecated_model": True,
},
)
@mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
- @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job")
+ @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job")
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
@@ -1022,15 +1027,17 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
self,
mock_get_model_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
- get_model_id_version_from_training_job: mock.Mock,
+ get_model_info_from_training_job: mock.Mock,
mock_attach: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
- get_model_id_version_from_training_job.return_value = (
+ get_model_info_from_training_job.return_value = (
"js-trainable-model-prepacked",
"1.0.0",
+ None,
+ None,
)
mock_get_model_specs.side_effect = get_special_model_spec
@@ -1041,7 +1048,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
training_job_name="some-training-job-name", sagemaker_session=mock_session
)
- get_model_id_version_from_training_job.assert_called_once_with(
+ get_model_info_from_training_job.assert_called_once_with(
training_job_name="some-training-job-name",
sagemaker_session=mock_session,
)
@@ -1053,11 +1060,13 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
additional_kwargs={
"model_id": "js-trainable-model-prepacked",
"model_version": "1.0.0",
+ "tolerate_vulnerable_model": True,
+ "tolerate_deprecated_model": True,
},
)
@mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
- @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job")
+ @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job")
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
@@ -1065,13 +1074,13 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case(
self,
mock_get_model_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
- get_model_id_version_from_training_job: mock.Mock,
+ get_model_info_from_training_job: mock.Mock,
mock_attach: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
- get_model_id_version_from_training_job.side_effect = ValueError()
+ get_model_info_from_training_job.side_effect = ValueError()
mock_get_model_specs.side_effect = get_special_model_spec
@@ -1082,7 +1091,7 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case(
training_job_name="some-training-job-name", sagemaker_session=mock_session
)
- get_model_id_version_from_training_job.assert_called_once_with(
+ get_model_info_from_training_job.assert_called_once_with(
training_job_name="some-training-job-name",
sagemaker_session=mock_session,
)
@@ -1109,6 +1118,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
"region",
"tolerate_vulnerable_model",
"tolerate_deprecated_model",
+ "config_name",
}
assert parent_class_init_args - js_class_init_args == init_args_to_skip
@@ -1130,7 +1140,9 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
js_class_deploy = JumpStartEstimator.deploy
js_class_deploy_args = set(signature(js_class_deploy).parameters.keys())
- assert js_class_deploy_args - parent_class_deploy_args == model_class_init_args - {
+ assert js_class_deploy_args - parent_class_deploy_args - {
+ "inference_config_name"
+ } == model_class_init_args - {
"model_data",
"self",
"name",
@@ -1209,6 +1221,7 @@ def test_no_predictor_returns_default_predictor(
tolerate_deprecated_model=False,
tolerate_vulnerable_model=False,
sagemaker_session=estimator.sagemaker_session,
+ config_name=None,
)
self.assertEqual(type(predictor), Predictor)
self.assertEqual(predictor, default_predictor_with_presets)
@@ -1366,6 +1379,7 @@ def test_incremental_training_with_unsupported_model_logs_warning(
tolerate_deprecated_model=False,
tolerate_vulnerable_model=False,
sagemaker_session=sagemaker_session,
+ config_name=None,
)
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
@@ -1417,6 +1431,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning(
tolerate_deprecated_model=False,
tolerate_vulnerable_model=False,
sagemaker_session=sagemaker_session,
+ config_name=None,
)
@mock.patch("sagemaker.utils.sagemaker_timestamp")
@@ -1848,6 +1863,268 @@ def test_jumpstart_estimator_session(
assert len(s3_clients) == 1
assert list(s3_clients)[0] == session.s3_client
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.estimator.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
+ @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_estimator_initialization_with_config_name(
+ self,
+ mock_estimator_init: mock.Mock,
+ mock_estimator_fit: mock.Mock,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ ):
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+ mock_estimator_fit.return_value = default_predictor
+
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_session.return_value = sagemaker_session
+
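+ # config_name pins the "gpu-training" config, which determines the model_uri and tags asserted below.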
+ estimator = JumpStartEstimator(
+ model_id=model_id,
+ config_name="gpu-training",
+ )
+
+ mock_estimator_init.assert_called_once_with(
+ instance_type="ml.p2.xlarge",
+ instance_count=1,
+ image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3",
+ model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/"
+ "gpu-training/model/",
+ source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/"
+ "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz",
+ entry_point="transfer_learning.py",
+ hyperparameters={"epochs": "3", "adam-learning-rate": "2e-05", "batch-size": "4"},
+ role="fake role! do not use!",
+ sagemaker_session=sagemaker_session,
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training"},
+ ],
+ enable_network_isolation=False,
+ )
+
+ estimator.fit()
+
+ mock_estimator_fit.assert_called_once_with(wait=True)
+
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.estimator.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
+ @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_estimator_set_config_name(
+ self,
+ mock_estimator_init: mock.Mock,
+ mock_estimator_fit: mock.Mock,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ ):
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+ mock_estimator_fit.return_value = default_predictor
+
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_session.return_value = sagemaker_session
+
+ estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training")
+
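+ # Switching to "gpu-training-budget" should re-resolve the config-specific model_uri and tags (asserted below).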
+ estimator.set_training_config(config_name="gpu-training-budget")
+
+ mock_estimator_init.assert_called_with(
+ instance_type="ml.p2.xlarge",
+ instance_count=1,
+ image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3",
+ model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/"
+ "gpu-training-budget/model/",
+ source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/"
+ "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz",
+ entry_point="transfer_learning.py",
+ hyperparameters={"epochs": "3", "adam-learning-rate": "2e-05", "batch-size": "4"},
+ role="fake role! do not use!",
+ sagemaker_session=sagemaker_session,
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training-budget"},
+ ],
+ enable_network_isolation=False,
+ )
+
+ estimator.fit()
+
+ mock_estimator_fit.assert_called_once_with(wait=True)
+
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.estimator.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
+ @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_estimator_default_inference_config(
+ self,
+ mock_estimator_fit: mock.Mock,
+ mock_estimator_deploy: mock.Mock,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ ):
+ mock_estimator_deploy.return_value = default_predictor
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+ mock_estimator_fit.return_value = default_predictor
+
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_session.return_value = sagemaker_session
+
+ estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training")
+
+ assert estimator.config_name == "gpu-training"
+
+ estimator.deploy()
+
+ mock_estimator_deploy.assert_called_once_with(
+ instance_type="ml.p2.xlarge",
+ initial_instance_count=1,
+ image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3",
+ source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/"
+ "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz",
+ entry_point="inference.py",
+ predictor_cls=Predictor,
+ wait=True,
+ role="fake role! do not use!",
+ use_compiled_model=False,
+ enable_network_isolation=False,
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference"},
+ ],
+ )
+
+ @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
+ @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.estimator.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
+ @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_estimator_incremental_training_config(
+ self,
+ mock_estimator_init: mock.Mock,
+ mock_estimator_fit: mock.Mock,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ mock_get_model_info_from_training_job: mock.Mock,
+ mock_attach: mock.Mock,
+ ):
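+ # The final element is the training config name; attach() should propagate it as config_name (asserted below).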
+ mock_get_model_info_from_training_job.return_value = (
+ "pytorch-eqa-bert-base-cased",
+ "1.0.0",
+ None,
+ "gpu-training-budget",
+ )
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+ mock_estimator_fit.return_value = default_predictor
+
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_session.return_value = sagemaker_session
+
+ estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training")
+
+ assert estimator.config_name == "gpu-training"
+
+ JumpStartEstimator.attach(
+ training_job_name="some-training-job-name", sagemaker_session=mock_session
+ )
+
+ mock_attach.assert_called_once_with(
+ training_job_name="some-training-job-name",
+ sagemaker_session=mock_session,
+ model_channel_name="model",
+ additional_kwargs={
+ "model_id": "pytorch-eqa-bert-base-cased",
+ "model_version": "1.0.0",
+ "tolerate_vulnerable_model": True,
+ "tolerate_deprecated_model": True,
+ "config_name": "gpu-training-budget",
+ },
+ )
+
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.estimator.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
+ @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
+ @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_estimator_deploy_with_config(
+ self,
+ mock_estimator_init: mock.Mock,
+ mock_estimator_fit: mock.Mock,
+ mock_estimator_deploy: mock.Mock,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ ):
+ mock_estimator_deploy.return_value = default_predictor
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+ mock_estimator_fit.return_value = default_predictor
+
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_session.return_value = sagemaker_session
+
+ estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training-budget")
+
+ assert estimator.config_name == "gpu-training-budget"
+
+ estimator.deploy()
+
+ mock_estimator_deploy.assert_called_once_with(
+ instance_type="ml.p2.xlarge",
+ initial_instance_count=1,
+ image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3",
+ source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/"
+ "pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz",
+ entry_point="inference.py",
+ predictor_cls=Predictor,
+ wait=True,
+ role="fake role! do not use!",
+ use_compiled_model=False,
+ enable_network_isolation=False,
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference-budget"},
+ ],
+ )
+
def test_jumpstart_estimator_requires_model_id():
with pytest.raises(ValueError):
diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py
index 8b00eb5bcd..e80ce020f7 100644
--- a/tests/unit/sagemaker/jumpstart/model/test_model.py
+++ b/tests/unit/sagemaker/jumpstart/model/test_model.py
@@ -15,6 +15,8 @@
from typing import Optional, Set
from unittest import mock
import unittest
+
+import pandas as pd
from mock import MagicMock, Mock
import pytest
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
@@ -40,12 +42,18 @@
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from tests.unit.sagemaker.jumpstart.utils import (
+ get_prototype_spec_with_configs,
get_spec_from_base_spec,
get_special_model_spec,
overwrite_dictionary,
get_special_model_spec_for_inference_component_based_endpoint,
get_prototype_manifest,
get_prototype_model_spec,
+ get_base_spec_with_prototype_configs,
+ get_mock_init_kwargs,
+ get_base_deployment_configs,
+ get_base_spec_with_prototype_configs_with_missing_benchmarks,
+ append_instance_stat_metrics,
)
import boto3
@@ -60,9 +68,11 @@
class ModelTest(unittest.TestCase):
-
mock_session_empty_config = MagicMock(sagemaker_config={})
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER")
@mock.patch("sagemaker.utils.sagemaker_timestamp")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@@ -80,6 +90,7 @@ def test_non_prepacked(
mock_validate_model_id_and_get_type: mock.Mock,
mock_sagemaker_timestamp: mock.Mock,
mock_jumpstart_model_factory_logger: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_model_deploy.return_value = default_predictor
@@ -139,6 +150,9 @@ def test_non_prepacked(
endpoint_logging=False,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.utils.sagemaker_timestamp")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -154,6 +168,7 @@ def test_non_prepacked_inference_component_based_endpoint(
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
mock_sagemaker_timestamp: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_model_deploy.return_value = default_predictor
@@ -219,6 +234,9 @@ def test_non_prepacked_inference_component_based_endpoint(
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.utils.sagemaker_timestamp")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -234,6 +252,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
mock_sagemaker_timestamp: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_model_deploy.return_value = default_predictor
@@ -294,6 +313,9 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -307,6 +329,7 @@ def test_prepacked(
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_model_deploy.return_value = default_predictor
@@ -353,6 +376,9 @@ def test_prepacked(
endpoint_logging=False,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.model.LOGGER.warning")
@mock.patch("sagemaker.utils.sagemaker_timestamp")
@mock.patch("sagemaker.session.Session.endpoint_from_production_variants")
@@ -370,6 +396,7 @@ def test_no_compiled_model_warning_log_js_models(
mock_endpoint_from_production_variants: mock.Mock,
mock_timestamp: mock.Mock,
mock_warning: mock.Mock(),
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_timestamp.return_value = "1234"
@@ -390,6 +417,9 @@ def test_no_compiled_model_warning_log_js_models(
mock_warning.assert_not_called()
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.utils.sagemaker_timestamp")
@mock.patch("sagemaker.session.Session.endpoint_from_production_variants")
@mock.patch("sagemaker.session.Session.create_model")
@@ -405,6 +435,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model(
mock_create_model: mock.Mock,
mock_endpoint_from_production_variants: mock.Mock,
mock_timestamp: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_timestamp.return_value = "1234"
@@ -452,6 +483,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model(
],
)
+ @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@mock.patch("sagemaker.utils.sagemaker_timestamp")
@mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type")
@@ -469,7 +501,9 @@ def test_proprietary_model_endpoint(
mock_validate_model_id_and_get_type: mock.Mock,
mock_sagemaker_timestamp: mock.Mock,
mock_get_manifest: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
+ mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {}
mock_get_manifest.side_effect = (
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
)
@@ -509,6 +543,7 @@ def test_proprietary_model_endpoint(
container_startup_health_check_timeout=600,
)
+ @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@@ -520,7 +555,9 @@ def test_deprecated(
mock_model_init: mock.Mock,
mock_get_model_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
+ mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {}
mock_model_deploy.return_value = default_predictor
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -536,6 +573,9 @@ def test_deprecated(
JumpStartModel(model_id=model_id, tolerate_deprecated_model=True).deploy()
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@@ -547,6 +587,7 @@ def test_vulnerable(
mock_model_init: mock.Mock,
mock_get_model_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -610,6 +651,9 @@ def test_model_use_kwargs(self):
deploy_kwargs=all_deploy_kwargs_used,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -625,6 +669,7 @@ def evaluate_model_workflow_with_kwargs(
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
mock_retrieve_environment_variables: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
init_kwargs: Optional[dict] = None,
deploy_kwargs: Optional[dict] = None,
):
@@ -715,6 +760,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
"tolerate_deprecated_model",
"instance_type",
"model_package_arn",
+ "config_name",
}
assert parent_class_init_args - js_class_init_args == init_args_to_skip
@@ -727,6 +773,9 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
assert js_class_deploy_args - parent_class_deploy_args == set()
assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@@ -735,6 +784,7 @@ def test_validate_model_id_and_get_type(
mock_validate_model_id_and_get_type: mock.Mock,
mock_init: mock.Mock,
mock_get_init_kwargs: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
JumpStartModel(model_id="valid_model_id")
@@ -743,6 +793,9 @@ def test_validate_model_id_and_get_type(
with pytest.raises(ValueError):
JumpStartModel(model_id="invalid_model_id")
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.get_default_predictor")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -758,6 +811,7 @@ def test_no_predictor_returns_default_predictor(
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
mock_get_default_predictor: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_get_default_predictor.return_value = default_predictor_with_presets
@@ -786,10 +840,14 @@ def test_no_predictor_returns_default_predictor(
tolerate_vulnerable_model=False,
sagemaker_session=model.sagemaker_session,
model_type=JumpStartModelType.OPEN_WEIGHTS,
+ config_name=None,
)
self.assertEqual(type(predictor), Predictor)
self.assertEqual(predictor, default_predictor_with_presets)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.get_default_predictor")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -805,6 +863,7 @@ def test_no_predictor_yes_async_inference_config(
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
mock_get_default_predictor: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_get_default_predictor.return_value = default_predictor_with_presets
@@ -826,6 +885,9 @@ def test_no_predictor_yes_async_inference_config(
mock_get_default_predictor.assert_not_called()
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.get_default_predictor")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -841,6 +903,7 @@ def test_yes_predictor_returns_default_predictor(
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
mock_get_default_predictor: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_get_default_predictor.return_value = default_predictor_with_presets
@@ -862,6 +925,9 @@ def test_yes_predictor_returns_default_predictor(
self.assertEqual(type(predictor), Predictor)
self.assertEqual(predictor, default_predictor)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -877,6 +943,7 @@ def test_model_id_not_found_refeshes_cache_inference(
mock_retrieve_kwargs: mock.Mock,
mock_model_init: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.side_effect = [False, False]
@@ -945,6 +1012,9 @@ def test_model_id_not_found_refeshes_cache_inference(
]
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
@@ -952,6 +1022,7 @@ def test_jumpstart_model_tags(
self,
mock_get_model_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -981,6 +1052,9 @@ def test_jumpstart_model_tags(
[{"Key": "blah", "Value": "blahagain"}] + js_tags,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
@@ -988,6 +1062,7 @@ def test_jumpstart_model_tags_disabled(
self,
mock_get_model_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -1015,6 +1090,9 @@ def test_jumpstart_model_tags_disabled(
[{"Key": "blah", "Value": "blahagain"}],
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
@@ -1022,6 +1100,7 @@ def test_jumpstart_model_package_arn(
self,
mock_get_model_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -1049,6 +1128,9 @@ def test_jumpstart_model_package_arn(
self.assertIn(tag, mock_session.create_model.call_args[1]["tags"])
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
@@ -1056,6 +1138,7 @@ def test_jumpstart_model_package_arn_override(
self,
mock_get_model_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -1091,6 +1174,9 @@ def test_jumpstart_model_package_arn_override(
},
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -1100,6 +1186,7 @@ def test_jumpstart_model_package_arn_unsupported_region(
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -1117,6 +1204,9 @@ def test_jumpstart_model_package_arn_unsupported_region(
"us-east-2. Please try one of the following regions: us-west-2, us-east-1."
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.utils.sagemaker_timestamp")
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -1134,6 +1224,7 @@ def test_model_data_s3_prefix_override(
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
mock_sagemaker_timestamp: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_model_deploy.return_value = default_predictor
@@ -1183,6 +1274,9 @@ def test_model_data_s3_prefix_override(
'"S3DataType": "S3Prefix", "CompressionType": "None"}}',
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -1198,6 +1292,7 @@ def test_model_data_s3_prefix_model(
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_model_deploy.return_value = default_predictor
@@ -1227,6 +1322,9 @@ def test_model_data_s3_prefix_model(
mock_js_info_logger.assert_not_called()
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -1242,6 +1340,7 @@ def test_model_artifact_variant_model(
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_model_deploy.return_value = default_predictor
@@ -1292,6 +1391,9 @@ def test_model_artifact_variant_model(
enable_network_isolation=True,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -1305,6 +1407,7 @@ def test_model_registry_accept_and_response_types(
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_model_deploy.return_value = default_predictor
@@ -1324,6 +1427,9 @@ def test_model_registry_accept_and_response_types(
response_types=["application/json;verbose", "application/json"],
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.get_default_predictor")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
@@ -1337,6 +1443,7 @@ def test_jumpstart_model_session(
mock_deploy,
mock_init,
get_default_predictor,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = True
@@ -1370,6 +1477,9 @@ def test_jumpstart_model_session(
assert len(s3_clients) == 1
assert list(s3_clients)[0] == session.s3_client
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch.dict(
"sagemaker.jumpstart.cache.os.environ",
{
@@ -1388,6 +1498,7 @@ def test_model_local_mode(
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_get_manifest: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_get_model_specs.side_effect = get_prototype_model_spec
mock_get_manifest.side_effect = (
@@ -1414,6 +1525,454 @@ def test_model_local_mode(
endpoint_logging=False,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.model.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.model.Model.deploy")
+ @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_model_initialization_with_config_name(
+ self,
+ mock_model_deploy: mock.Mock,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
+ ):
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+ mock_model_deploy.return_value = default_predictor
+
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_session.return_value = sagemaker_session
+
+ model = JumpStartModel(model_id=model_id, config_name="neuron-inference")
+
+ assert model.config_name == "neuron-inference"
+
+ model.deploy()
+
+ mock_model_deploy.assert_called_once_with(
+ initial_instance_count=1,
+ instance_type="ml.inf2.xlarge",
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"},
+ ],
+ wait=True,
+ endpoint_logging=False,
+ )
+
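Stepping back from the mock plumbing, the test above pins the user-facing behavior: a deployment config can be selected at construction time, and it drives both the default instance type and the tags on the deployed endpoint. A minimal usage sketch, assuming a model whose spec actually ships a `neuron-inference` config (the ids here are fixture values from this test, not recommendations):

```python
from sagemaker.jumpstart.model import JumpStartModel

# Fixture model id and config name from the test above; real values
# come from the model's published spec.
model = JumpStartModel(
    model_id="pytorch-eqa-bert-base-cased",
    config_name="neuron-inference",
)

# The selected config supplies the default instance type (ml.inf2.xlarge in
# the test) and is recorded on the endpoint via an inference-config-name tag.
predictor = model.deploy()
```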
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.model.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.model.Model.deploy")
+ @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_model_set_deployment_config(
+ self,
+ mock_model_deploy: mock.Mock,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
+ ):
+ mock_get_model_specs.side_effect = get_prototype_model_spec
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+ mock_model_deploy.return_value = default_predictor
+
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_session.return_value = sagemaker_session
+
+ model = JumpStartModel(model_id=model_id)
+
+ assert model.config_name is None
+
+ model.deploy()
+
+ mock_model_deploy.assert_called_once_with(
+ initial_instance_count=1,
+ instance_type="ml.p2.xlarge",
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ ],
+ wait=True,
+ endpoint_logging=False,
+ )
+
+ mock_get_model_specs.reset_mock()
+ mock_model_deploy.reset_mock()
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ model.set_deployment_config("neuron-inference", "ml.inf2.2xlarge")
+
+ assert model.config_name == "neuron-inference"
+
+ model.deploy()
+
+ mock_model_deploy.assert_called_once_with(
+ initial_instance_count=1,
+ instance_type="ml.inf2.2xlarge",
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"},
+ ],
+ wait=True,
+ endpoint_logging=False,
+ )
+ mock_model_deploy.reset_mock()
+ model.set_deployment_config("neuron-inference", "ml.inf2.xlarge")
+
+ assert model.config_name == "neuron-inference"
+
+ model.deploy()
+
+ mock_model_deploy.assert_called_once_with(
+ initial_instance_count=1,
+ instance_type="ml.inf2.xlarge",
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"},
+ ],
+ wait=True,
+ endpoint_logging=False,
+ )
+
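The assertions above also fix the contract of `set_deployment_config`: it may be called repeatedly after construction, and each call re-resolves the deployment kwargs so that the next `deploy()` uses the latest config name and instance type. Sketched as user code (instance types copied from the test; actual availability depends on the model):

```python
from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")
assert model.config_name is None  # no config selected yet, for this fixture

model.set_deployment_config("neuron-inference", "ml.inf2.2xlarge")
model.deploy()  # lands on ml.inf2.2xlarge, tagged with the config name

# Only the most recent call influences the next deploy().
model.set_deployment_config("neuron-inference", "ml.inf2.xlarge")
model.deploy()  # now ml.inf2.xlarge, same config tag
```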
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.model.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.model.Model.deploy")
+ @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_model_set_deployment_config_model_package(
+ self,
+ mock_model_deploy: mock.Mock,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
+ ):
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+ mock_model_deploy.return_value = default_predictor
+
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_session.return_value = sagemaker_session
+
+ model = JumpStartModel(model_id=model_id)
+
+ assert model.config_name == "neuron-inference"
+
+ model.deploy()
+
+ mock_model_deploy.assert_called_once_with(
+ initial_instance_count=1,
+ instance_type="ml.inf2.xlarge",
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"},
+ ],
+ wait=True,
+ endpoint_logging=False,
+ )
+
+ mock_model_deploy.reset_mock()
+
+ model.set_deployment_config(
+ config_name="gpu-inference-model-package", instance_type="ml.p2.xlarge"
+ )
+
+ assert (
+ model.model_package_arn
+ == "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
+ )
+ model.deploy()
+
+ mock_model_deploy.assert_called_once_with(
+ initial_instance_count=1,
+ instance_type="ml.p2.xlarge",
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference-model-package"},
+ ],
+ wait=True,
+ endpoint_logging=False,
+ )
+
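A config can also be backed by a model package, in which case selecting it swaps the model's `model_package_arn` and the subsequent deploy goes through that package rather than raw artifacts. Sketch (the `gpu-inference-model-package` config name is a fixture from this test):

```python
from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")
model.set_deployment_config(
    config_name="gpu-inference-model-package",  # fixture config name
    instance_type="ml.p2.xlarge",
)
print(model.model_package_arn)  # now the ARN resolved from the chosen config
model.deploy()
```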
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.model.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.model.Model.deploy")
+ @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_model_set_deployment_config_incompatible_instance_type_or_name(
+ self,
+ mock_model_deploy: mock.Mock,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
+ ):
+ mock_get_model_specs.side_effect = get_prototype_model_spec
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+ mock_model_deploy.return_value = default_predictor
+
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_session.return_value = sagemaker_session
+
+ model = JumpStartModel(model_id=model_id)
+
+ assert model.config_name is None
+
+ model.deploy()
+
+ mock_model_deploy.assert_called_once_with(
+ initial_instance_count=1,
+ instance_type="ml.p2.xlarge",
+ tags=[
+ {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
+ ],
+ wait=True,
+ endpoint_logging=False,
+ )
+
+ mock_get_model_specs.reset_mock()
+ mock_model_deploy.reset_mock()
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ with pytest.raises(ValueError) as error:
+ model.set_deployment_config("neuron-inference", "ml.inf2.32xlarge")
+ assert (
+ "Instance type ml.inf2.32xlarge is not supported for config neuron-inference."
+ in str(error)
+ )
+
+ with pytest.raises(ValueError) as error:
+ model.set_deployment_config("neuron-inference-unknown-name", "ml.inf2.32xlarge")
+ assert "Cannot find Jumpstart config name neuron-inference-unknown-name. " in str(error)
+
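The incompatibility test documents the two failure modes, both raised as `ValueError`: an instance type the named config does not support, and a config name that does not exist for the model. A caller probing configs defensively might look like this (messages quoted from the assertions above):

```python
from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")
try:
    model.set_deployment_config("neuron-inference", "ml.inf2.32xlarge")
except ValueError as err:
    # e.g. "Instance type ml.inf2.32xlarge is not supported for config neuron-inference."
    # or   "Cannot find Jumpstart config name ..."
    print(f"keeping the default config: {err}")
```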
+ @mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
+ @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
+ @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.model.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_model_list_deployment_configs(
+ self,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock,
+ mock_verify_model_region_and_return_specs: mock.Mock,
+ mock_get_init_kwargs: mock.Mock,
+ ):
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
+ mock_verify_model_region_and_return_specs.side_effect = (
+ lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks()
+ )
+ mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: (
+ None,
+ append_instance_stat_metrics(metrics),
+ )
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+
+ mock_session.return_value = sagemaker_session
+
+ model = JumpStartModel(model_id=model_id)
+
+ configs = model.list_deployment_configs()
+
+ self.assertEqual(configs, get_base_deployment_configs(True))
+
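`list_deployment_configs` returns the fully resolved config dictionaries, with instance-rate stats merged into the benchmark metrics when the pricing lookup succeeds; the companion test below shows it returning an empty list for models that ship no configs. Sketch, using key names that appear in `test_model_retrieve_deployment_config` further down:

```python
from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")
for cfg in model.list_deployment_configs():
    # Key names assumed from the retrieve test below.
    print(cfg.get("DeploymentConfigName"), cfg.get("InstanceType"))
```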
+ @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.model.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_model_list_deployment_configs_empty(
+ self,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ mock_verify_model_region_and_return_specs: mock.Mock,
+ ):
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_verify_model_region_and_return_specs.side_effect = (
+ lambda *args, **kwargs: get_special_model_spec(model_id="gemma-model")
+ )
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+
+ mock_session.return_value = sagemaker_session
+
+ model = JumpStartModel(model_id=model_id)
+
+ configs = model.list_deployment_configs()
+
+ self.assertEqual(len(configs), 0)
+
+ @mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
+ @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
+ @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.model.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.model.Model.deploy")
+ @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_model_retrieve_deployment_config(
+ self,
+ mock_model_deploy: mock.Mock,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock,
+ mock_verify_model_region_and_return_specs: mock.Mock,
+ mock_get_init_kwargs: mock.Mock,
+ ):
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_verify_model_region_and_return_specs.side_effect = (
+ lambda *args, **kwargs: get_base_spec_with_prototype_configs()
+ )
+ mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: (
+ None,
+ append_instance_stat_metrics(metrics),
+ )
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+ mock_model_deploy.return_value = default_predictor
+
+ expected = get_base_deployment_configs()[0]
+ config_name = expected.get("DeploymentConfigName")
+ instance_type = expected.get("InstanceType")
+ mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(
+ model_id, config_name
+ )
+
+ mock_session.return_value = sagemaker_session
+
+ model = JumpStartModel(model_id=model_id)
+
+ model.set_deployment_config(config_name, instance_type)
+
+ self.assertEqual(model.deployment_config, expected)
+
+ mock_get_init_kwargs.reset_mock()
+ mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
+
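After a successful `set_deployment_config`, the `deployment_config` property echoes back the full resolved config dict, letting a caller verify exactly what the next deploy will use. Sketch (assumes the model exposes at least one config):

```python
from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")
cfg = model.list_deployment_configs()[0]
model.set_deployment_config(cfg["DeploymentConfigName"], cfg["InstanceType"])
assert model.deployment_config == cfg  # mirrors the equality check above
```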
+ @mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
+ @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
+ @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.model.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_model_display_benchmark_metrics(
+ self,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock,
+ mock_verify_model_region_and_return_specs: mock.Mock,
+ mock_get_init_kwargs: mock.Mock,
+ ):
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
+ mock_verify_model_region_and_return_specs.side_effect = (
+ lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks()
+ )
+ mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: (
+ None,
+ append_instance_stat_metrics(metrics),
+ )
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+
+ mock_session.return_value = sagemaker_session
+
+ model = JumpStartModel(model_id=model_id)
+
+ model.display_benchmark_metrics()
+ model.display_benchmark_metrics(instance_type="g5.12xlarge")
+
+ @mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
+ @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
+ @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+ @mock.patch("sagemaker.jumpstart.factory.model.Session")
+ @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+ def test_model_benchmark_metrics(
+ self,
+ mock_get_model_specs: mock.Mock,
+ mock_session: mock.Mock,
+ mock_get_manifest: mock.Mock,
+ mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock,
+ mock_verify_model_region_and_return_specs: mock.Mock,
+ mock_get_init_kwargs: mock.Mock,
+ ):
+ model_id, _ = "pytorch-eqa-bert-base-cased", "*"
+
+ mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
+ mock_verify_model_region_and_return_specs.side_effect = (
+ lambda *args, **kwargs: get_base_spec_with_prototype_configs()
+ )
+ mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: (
+ None,
+ append_instance_stat_metrics(metrics),
+ )
+ mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+ mock_get_manifest.side_effect = (
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ )
+
+ mock_session.return_value = sagemaker_session
+
+ model = JumpStartModel(model_id=model_id)
+
+ df = model.benchmark_metrics
+
+ self.assertTrue(isinstance(df, pd.DataFrame))
+
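Together these two tests cover the read-only reporting surface: `display_benchmark_metrics()` prints the per-config benchmark table (optionally restricted to one instance type), and the `benchmark_metrics` property returns the same data as a `pandas.DataFrame`. Usage sketch:

```python
import pandas as pd

from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")

model.display_benchmark_metrics()                             # all instance types
model.display_benchmark_metrics(instance_type="g5.12xlarge")  # one instance type

df = model.benchmark_metrics
assert isinstance(df, pd.DataFrame)  # mirrors the assertion in the test
```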
def test_jumpstart_model_requires_model_id():
with pytest.raises(ValueError):
diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py
index 70409704e6..2be4bde7e4 100644
--- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py
+++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py
@@ -60,6 +60,9 @@ class IntelligentDefaultsModelTest(unittest.TestCase):
region = "us-west-2"
sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -77,6 +80,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config(
mock_retrieve_kwargs: mock.Mock,
mock_model_init: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -101,6 +105,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config(
assert "enable_network_isolation" not in mock_model_init.call_args[1]
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -118,6 +125,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config(
mock_retrieve_kwargs: mock.Mock,
mock_model_init: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -147,6 +155,9 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config(
override_enable_network_isolation,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -164,6 +175,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config(
mock_retrieve_kwargs: mock.Mock,
mock_model_init: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -193,6 +205,9 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config(
config_enable_network_isolation,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -210,6 +225,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config(
mock_retrieve_kwargs: mock.Mock,
mock_model_init: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -241,6 +257,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config(
override_enable_network_isolation,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -258,6 +277,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config(
mock_retrieve_kwargs: mock.Mock,
mock_model_init: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -287,6 +307,9 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config(
metadata_enable_network_isolation,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -304,6 +327,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config(
mock_retrieve_kwargs: mock.Mock,
mock_model_init: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -334,6 +358,9 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config(
override_enable_network_isolation,
)
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -351,6 +378,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config(
mock_retrieve_kwargs: mock.Mock,
mock_model_init: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -375,6 +403,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config(
self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role)
assert "enable_network_isolation" not in mock_model_init.call_args[1]
+ @mock.patch(
+ "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+ )
@mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -392,6 +423,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config(
mock_retrieve_kwargs: mock.Mock,
mock_model_init: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
+ mock_get_jumpstart_configs: mock.Mock,
):
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py
index c00d271ef1..301afe4d53 100644
--- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py
+++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py
@@ -1,9 +1,10 @@
from __future__ import absolute_import
+
+import datetime
import json
from unittest import TestCase
from unittest.mock import Mock, patch
-import datetime
import pytest
from sagemaker.jumpstart.constants import (
@@ -17,7 +18,6 @@
get_prototype_manifest,
get_prototype_model_spec,
)
-from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.notebook_utils import (
_generate_jumpstart_model_versions,
@@ -227,10 +227,6 @@ def test_list_jumpstart_models_simple_case(
patched_get_manifest.assert_called()
patched_get_model_specs.assert_not_called()
- @pytest.mark.skipif(
- datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
- reason="Contact JumpStart team to fix flaky test.",
- )
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
def test_list_jumpstart_models_script_filter(
@@ -240,29 +236,31 @@ def test_list_jumpstart_models_script_filter(
get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
)
patched_get_manifest.side_effect = (
- lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+ lambda region, model_type, *args, **kwargs: get_prototype_manifest(region)
)
manifest_length = len(get_prototype_manifest())
vals = [True, False]
for val in vals:
- kwargs = {"filter": f"training_supported == {val}"}
+ kwargs = {"filter": And(f"training_supported == {val}", "model_type is open_weights")}
list_jumpstart_models(**kwargs)
- assert patched_read_s3_file.call_count == manifest_length
- patched_get_manifest.assert_called_once()
+ assert patched_read_s3_file.call_count == 2 * manifest_length
+ assert patched_get_manifest.call_count == 2
patched_get_manifest.reset_mock()
patched_read_s3_file.reset_mock()
- kwargs = {"filter": f"training_supported != {val}"}
+ kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")}
list_jumpstart_models(**kwargs)
- assert patched_read_s3_file.call_count == manifest_length
+ assert patched_read_s3_file.call_count == 2 * manifest_length
assert patched_get_manifest.call_count == 2
patched_get_manifest.reset_mock()
patched_read_s3_file.reset_mock()
-
- kwargs = {"filter": f"training_supported in {vals}", "list_versions": True}
+ kwargs = {
+ "filter": And(f"training_supported != {val}", "model_type is open_weights"),
+ "list_versions": True,
+ }
assert list_jumpstart_models(**kwargs) == [
("catboost-classification-model", "1.0.0"),
("huggingface-spc-bert-base-cased", "1.0.0"),
@@ -273,16 +271,16 @@ def test_list_jumpstart_models_script_filter(
("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"),
("xgboost-classification-model", "1.0.0"),
]
- assert patched_read_s3_file.call_count == manifest_length
+ assert patched_read_s3_file.call_count == 2 * manifest_length
assert patched_get_manifest.call_count == 2
patched_get_manifest.reset_mock()
patched_read_s3_file.reset_mock()
- kwargs = {"filter": f"training_supported not in {vals}"}
+ kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")}
models = list_jumpstart_models(**kwargs)
assert [] == models
- assert patched_read_s3_file.call_count == manifest_length
+ assert patched_read_s3_file.call_count == 2 * manifest_length
assert patched_get_manifest.call_count == 2
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@@ -519,7 +517,7 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
) == list_jumpstart_models(list_versions=True)
@pytest.mark.skipif(
- datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
+ datetime.datetime.now() < datetime.datetime(year=2024, month=8, day=1),
reason="Contact JumpStart team to fix flaky test.",
)
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@@ -547,12 +545,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
patched_read_s3_file.side_effect = vulnerable_inference_model_spec
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
- num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
assert [] == list_jumpstart_models(
- And("inference_vulnerable is false", "training_vulnerable is false")
+ And(
+ "inference_vulnerable is false",
+ "training_vulnerable is false",
+ "model_type is open_weights",
+ )
)
- assert patched_read_s3_file.call_count == num_specs + num_prop_specs
+ assert patched_read_s3_file.call_count == num_specs
assert patched_get_manifest.call_count == 2
patched_get_manifest.reset_mock()
@@ -561,10 +562,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
patched_read_s3_file.side_effect = vulnerable_training_model_spec
assert [] == list_jumpstart_models(
- And("inference_vulnerable is false", "training_vulnerable is false")
+ And(
+ "inference_vulnerable is false",
+ "training_vulnerable is false",
+ "model_type is open_weights",
+ )
)
- assert patched_read_s3_file.call_count == num_specs + num_prop_specs
+ assert patched_read_s3_file.call_count == num_specs
assert patched_get_manifest.call_count == 2
patched_get_manifest.reset_mock()
@@ -574,10 +579,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
assert patched_read_s3_file.call_count == 0
- @pytest.mark.skipif(
- datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
- reason="Contact JumpStart team to fix flaky test.",
- )
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
def test_list_jumpstart_models_deprecated_models(
@@ -598,10 +599,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
patched_read_s3_file.side_effect = deprecated_model_spec
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
- num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
- assert [] == list_jumpstart_models("deprecated equals false")
+ assert [] == list_jumpstart_models(
+ And("deprecated equals false", "model_type is open_weights")
+ )
- assert patched_read_s3_file.call_count == num_specs + num_prop_specs
+ assert patched_read_s3_file.call_count == num_specs
assert patched_get_manifest.call_count == 2
patched_get_manifest.reset_mock()
diff --git a/tests/unit/sagemaker/jumpstart/test_payload_utils.py b/tests/unit/sagemaker/jumpstart/test_payload_utils.py
index afc955e2f3..3c339c9b95 100644
--- a/tests/unit/sagemaker/jumpstart/test_payload_utils.py
+++ b/tests/unit/sagemaker/jumpstart/test_payload_utils.py
@@ -32,10 +32,36 @@ def test_construct_payload(self, patched_get_model_specs):
region = "us-west-2"
constructed_payload_body = _construct_payload(
- prompt="kobebryant",
- model_id=model_id,
- model_version="*",
- region=region,
+ prompt="kobebryant", model_id=model_id, model_version="*", region=region
+ ).body
+
+ self.assertEqual(
+ {
+ "hello": {"prompt": "kobebryant"},
+ "seed": 43,
+ },
+ constructed_payload_body,
+ )
+
+ # Unsupported model
+ self.assertIsNone(
+ _construct_payload(
+ prompt="blah",
+ model_id="default_payloads",
+ model_version="*",
+ region=region,
+ )
+ )
+
+ @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+ def test_construct_payload_with_specific_alias(self, patched_get_model_specs):
+ patched_get_model_specs.side_effect = get_special_model_spec
+
+ model_id = "prompt-key"
+ region = "us-west-2"
+
+ constructed_payload_body = _construct_payload(
+ prompt="kobebryant", model_id=model_id, model_version="*", region=region, alias="Dog"
).body
self.assertEqual(
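The new `alias` argument lets a caller select a specific variant from a model's default payloads rather than the primary one. A hedged sketch of the (private) helper, assuming it is importable from `sagemaker.jumpstart.payload_utils` as the test module name suggests; `prompt-key` and `"Dog"` are fixture values:

```python
from sagemaker.jumpstart.payload_utils import _construct_payload  # private helper

payload = _construct_payload(
    prompt="kobebryant",
    model_id="prompt-key",   # fixture model id
    model_version="*",
    region="us-west-2",
    alias="Dog",             # selects the aliased payload variant
)
body = payload.body if payload is not None else None  # None for unsupported models
```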
diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py
index 52f28f2da1..a3425a7b90 100644
--- a/tests/unit/sagemaker/jumpstart/test_predictor.py
+++ b/tests/unit/sagemaker/jumpstart/test_predictor.py
@@ -18,7 +18,7 @@
from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec
-@patch("sagemaker.predictor.get_model_id_version_from_endpoint")
+@patch("sagemaker.predictor.get_model_info_from_endpoint")
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_predictor_support(
@@ -52,7 +52,7 @@ def test_jumpstart_predictor_support(
assert js_predictor.accept == MIMEType.JSON
-@patch("sagemaker.predictor.get_model_id_version_from_endpoint")
+@patch("sagemaker.predictor.get_model_info_from_endpoint")
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_proprietary_predictor_support(
@@ -91,13 +91,13 @@ def test_proprietary_predictor_support(
@patch("sagemaker.predictor.Predictor")
@patch("sagemaker.predictor.get_default_predictor")
-@patch("sagemaker.predictor.get_model_id_version_from_endpoint")
+@patch("sagemaker.predictor.get_model_info_from_endpoint")
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_predictor_support_no_model_id_supplied_happy_case(
patched_get_model_specs,
patched_verify_model_region_and_return_specs,
- patched_get_jumpstart_model_id_version_from_endpoint,
+ patched_get_model_info_from_endpoint,
patched_get_default_predictor,
patched_predictor,
):
@@ -105,19 +105,19 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case(
patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs
patched_get_model_specs.side_effect = get_special_model_spec
- patched_get_jumpstart_model_id_version_from_endpoint.return_value = (
+ patched_get_model_info_from_endpoint.return_value = (
"predictor-specs-model",
"1.2.3",
None,
+ None,
+ None,
)
mock_session = Mock()
predictor.retrieve_default(endpoint_name="blah", sagemaker_session=mock_session)
- patched_get_jumpstart_model_id_version_from_endpoint.assert_called_once_with(
- "blah", None, mock_session
- )
+ patched_get_model_info_from_endpoint.assert_called_once_with("blah", None, mock_session)
patched_get_default_predictor.assert_called_once_with(
predictor=patched_predictor.return_value,
@@ -128,11 +128,12 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case(
tolerate_vulnerable_model=False,
sagemaker_session=mock_session,
model_type=JumpStartModelType.OPEN_WEIGHTS,
+ config_name=None,
)
@patch("sagemaker.predictor.get_default_predictor")
-@patch("sagemaker.predictor.get_model_id_version_from_endpoint")
+@patch("sagemaker.predictor.get_model_info_from_endpoint")
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_predictor_support_no_model_id_supplied_sad_case(
@@ -159,7 +160,8 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case(
patched_get_default_predictor.assert_not_called()
-@patch("sagemaker.predictor.get_model_id_version_from_endpoint")
+@patch("sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {})
+@patch("sagemaker.predictor.get_model_info_from_endpoint")
@patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached")
@patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
@@ -169,7 +171,8 @@ def test_jumpstart_serializable_payload_with_predictor(
patched_verify_model_region_and_return_specs,
patched_validate_model_id_and_get_type,
patched_get_object_cached,
- patched_get_model_id_version_from_endpoint,
+ patched_get_model_info_from_endpoint,
+ patched_get_jumpstart_configs,
):
patched_get_object_cached.return_value = base64.b64decode("encodedimage")
@@ -179,7 +182,7 @@ def test_jumpstart_serializable_payload_with_predictor(
patched_get_model_specs.side_effect = get_special_model_spec
model_id, model_version = "default_payloads", "*"
- patched_get_model_id_version_from_endpoint.return_value = model_id, model_version, None
+ patched_get_model_info_from_endpoint.return_value = model_id, model_version, None
js_predictor = predictor.retrieve_default(
endpoint_name="blah", model_id=model_id, model_version=model_version
diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py
index 76ad50f31c..ce06a189bd 100644
--- a/tests/unit/sagemaker/jumpstart/test_session_utils.py
+++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py
@@ -4,167 +4,202 @@
import pytest
from sagemaker.jumpstart.session_utils import (
- _get_model_id_version_from_inference_component_endpoint_with_inference_component_name,
- _get_model_id_version_from_inference_component_endpoint_without_inference_component_name,
- _get_model_id_version_from_model_based_endpoint,
- get_model_id_version_from_endpoint,
- get_model_id_version_from_training_job,
+ _get_model_info_from_inference_component_endpoint_with_inference_component_name,
+ _get_model_info_from_inference_component_endpoint_without_inference_component_name,
+ _get_model_info_from_model_based_endpoint,
+ get_model_info_from_endpoint,
+ get_model_info_from_training_job,
)
-@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn")
-def test_get_model_id_version_from_training_job_happy_case(
- mock_get_jumpstart_model_id_version_from_resource_arn,
+@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn")
+def test_get_model_info_from_training_job_happy_case(
+ mock_get_jumpstart_model_info_from_resource_arn,
):
mock_sm_session = Mock()
mock_sm_session.boto_region_name = "us-west-2"
mock_sm_session.account_id = Mock(return_value="123456789012")
- mock_get_jumpstart_model_id_version_from_resource_arn.return_value = (
+ mock_get_jumpstart_model_info_from_resource_arn.return_value = (
"model_id",
"model_version",
+ None,
+ None,
)
- retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session)
+ retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session)
- assert retval == ("model_id", "model_version")
+ assert retval == ("model_id", "model_version", None, None)
- mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with(
+ mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with(
"arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session
)
-@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn")
-def test_get_model_id_version_from_training_job_no_model_id_inferred(
- mock_get_jumpstart_model_id_version_from_resource_arn,
+@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn")
+def test_get_model_info_from_training_job_config_name(
+ mock_get_jumpstart_model_info_from_resource_arn,
):
mock_sm_session = Mock()
mock_sm_session.boto_region_name = "us-west-2"
mock_sm_session.account_id = Mock(return_value="123456789012")
- mock_get_jumpstart_model_id_version_from_resource_arn.return_value = (
+ mock_get_jumpstart_model_info_from_resource_arn.return_value = (
+ "model_id",
+ "model_version",
+ None,
+ "training_config_name",
+ )
+
+ retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session)
+
+ assert retval == ("model_id", "model_version", None, "training_config_name")
+
+ mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with(
+ "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session
+ )
+
+
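`get_model_info_from_training_job` mirrors the endpoint helper but yields four fields, with no inference-component slot and the training config name last. Sketch, with the same caveats (order inferred from the assertions above; requires a real session and training job):

```python
from sagemaker.jumpstart.session_utils import get_model_info_from_training_job

model_id, model_version, inference_config_name, training_config_name = (
    get_model_info_from_training_job("my-training-job")  # hypothetical job name
)
```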
+@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn")
+def test_get_model_info_from_training_job_no_model_id_inferred(
+ mock_get_jumpstart_model_info_from_resource_arn,
+):
+ mock_sm_session = Mock()
+ mock_sm_session.boto_region_name = "us-west-2"
+ mock_sm_session.account_id = Mock(return_value="123456789012")
+
+ mock_get_jumpstart_model_info_from_resource_arn.return_value = (
None,
None,
)
with pytest.raises(ValueError):
- get_model_id_version_from_training_job("blah", sagemaker_session=mock_sm_session)
+ get_model_info_from_training_job("blah", sagemaker_session=mock_sm_session)
-@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn")
-def test_get_model_id_version_from_model_based_endpoint_happy_case(
- mock_get_jumpstart_model_id_version_from_resource_arn,
+@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn")
+def test_get_model_info_from_model_based_endpoint_happy_case(
+ mock_get_jumpstart_model_info_from_resource_arn,
):
mock_sm_session = Mock()
mock_sm_session.boto_region_name = "us-west-2"
mock_sm_session.account_id = Mock(return_value="123456789012")
- mock_get_jumpstart_model_id_version_from_resource_arn.return_value = (
+ mock_get_jumpstart_model_info_from_resource_arn.return_value = (
"model_id",
"model_version",
+ None,
+ None,
)
- retval = _get_model_id_version_from_model_based_endpoint(
+ retval = _get_model_info_from_model_based_endpoint(
"bLaH", inference_component_name=None, sagemaker_session=mock_sm_session
)
- assert retval == ("model_id", "model_version")
+ assert retval == ("model_id", "model_version", None, None)
- mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with(
+ mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with(
"arn:aws:sagemaker:us-west-2:123456789012:endpoint/blah", mock_sm_session
)
-@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn")
-def test_get_model_id_version_from_model_based_endpoint_inference_component_supplied(
- mock_get_jumpstart_model_id_version_from_resource_arn,
+@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn")
+def test_get_model_info_from_model_based_endpoint_inference_component_supplied(
+ mock_get_jumpstart_model_info_from_resource_arn,
):
mock_sm_session = Mock()
mock_sm_session.boto_region_name = "us-west-2"
mock_sm_session.account_id = Mock(return_value="123456789012")
- mock_get_jumpstart_model_id_version_from_resource_arn.return_value = (
+ mock_get_jumpstart_model_info_from_resource_arn.return_value = (
"model_id",
"model_version",
+ None,
+ None,
)
with pytest.raises(ValueError):
- _get_model_id_version_from_model_based_endpoint(
+ _get_model_info_from_model_based_endpoint(
"blah", inference_component_name="some-name", sagemaker_session=mock_sm_session
)
-@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn")
-def test_get_model_id_version_from_model_based_endpoint_no_model_id_inferred(
- mock_get_jumpstart_model_id_version_from_resource_arn,
+@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn")
+def test_get_model_info_from_model_based_endpoint_no_model_id_inferred(
+ mock_get_jumpstart_model_info_from_resource_arn,
):
mock_sm_session = Mock()
mock_sm_session.boto_region_name = "us-west-2"
mock_sm_session.account_id = Mock(return_value="123456789012")
- mock_get_jumpstart_model_id_version_from_resource_arn.return_value = (
+ mock_get_jumpstart_model_info_from_resource_arn.return_value = (
+ None,
None,
None,
)
with pytest.raises(ValueError):
- _get_model_id_version_from_model_based_endpoint(
+ _get_model_info_from_model_based_endpoint(
"blah", inference_component_name="some-name", sagemaker_session=mock_sm_session
)
-@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn")
-def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_happy_case(
- mock_get_jumpstart_model_id_version_from_resource_arn,
+@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn")
+def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_happy_case(
+ mock_get_jumpstart_model_info_from_resource_arn,
):
mock_sm_session = Mock()
mock_sm_session.boto_region_name = "us-west-2"
mock_sm_session.account_id = Mock(return_value="123456789012")
- mock_get_jumpstart_model_id_version_from_resource_arn.return_value = (
+ mock_get_jumpstart_model_info_from_resource_arn.return_value = (
"model_id",
"model_version",
+ None,
+ None,
)
- retval = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
+ retval = _get_model_info_from_inference_component_endpoint_with_inference_component_name(
"bLaH", sagemaker_session=mock_sm_session
)
- assert retval == ("model_id", "model_version")
+ assert retval == ("model_id", "model_version", None, None)
- mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with(
+ mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with(
"arn:aws:sagemaker:us-west-2:123456789012:inference-component/bLaH", mock_sm_session
)
-@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn")
-def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred(
- mock_get_jumpstart_model_id_version_from_resource_arn,
+@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn")
+def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred(
+ mock_get_jumpstart_model_info_from_resource_arn,
):
mock_sm_session = Mock()
mock_sm_session.boto_region_name = "us-west-2"
mock_sm_session.account_id = Mock(return_value="123456789012")
- mock_get_jumpstart_model_id_version_from_resource_arn.return_value = (
+ mock_get_jumpstart_model_info_from_resource_arn.return_value = (
+ None,
+ None,
None,
None,
)
with pytest.raises(ValueError):
- _get_model_id_version_from_inference_component_endpoint_with_inference_component_name(
+ _get_model_info_from_inference_component_endpoint_with_inference_component_name(
"blah", sagemaker_session=mock_sm_session
)
@patch(
- "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_"
+ "sagemaker.jumpstart.session_utils._get_model_info_from_inference_"
"component_endpoint_with_inference_component_name"
)
-def test_get_model_id_version_from_inference_component_endpoint_without_inference_component_name_happy_case(
- mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name,
+def test_get_model_info_from_inference_component_endpoint_without_inference_component_name_happy_case(
+ mock_get_model_info_from_inference_component_endpoint_with_inference_component_name,
):
mock_sm_session = Mock()
- mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = (
+ mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = (
"model_id",
"model_version",
)
@@ -172,10 +207,8 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc
return_value=["icname"]
)
- retval = (
- _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
- "blahblah", mock_sm_session
- )
+ retval = _get_model_info_from_inference_component_endpoint_without_inference_component_name(
+ "blahblah", mock_sm_session
)
assert retval == ("model_id", "model_version", "icname")
@@ -185,14 +218,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc
@patch(
- "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_"
+ "sagemaker.jumpstart.session_utils._get_model_info_from_inference_"
"component_endpoint_with_inference_component_name"
)
-def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint(
- mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name,
+def test_get_model_info_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint(
+ mock_get_model_info_from_inference_component_endpoint_with_inference_component_name,
):
mock_sm_session = Mock()
- mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = (
+ mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = (
"model_id",
"model_version",
)
@@ -200,7 +233,7 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_
return_value=[]
)
with pytest.raises(ValueError):
- _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
+ _get_model_info_from_inference_component_endpoint_without_inference_component_name(
"blahblah", mock_sm_session
)
@@ -210,14 +243,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_
@patch(
- "sagemaker.jumpstart.session_utils._get_model_id"
- "_version_from_inference_component_endpoint_with_inference_component_name"
+ "sagemaker.jumpstart.session_utils._get_model"
+ "_info_from_inference_component_endpoint_with_inference_component_name"
)
def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_multiple_ics_for_endpoint(
- mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name,
+ mock_get_model_info_from_inference_component_endpoint_with_inference_component_name,
):
mock_sm_session = Mock()
- mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = (
+ mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = (
"model_id",
"model_version",
)
@@ -227,7 +260,7 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_
)
with pytest.raises(ValueError):
- _get_model_id_version_from_inference_component_endpoint_without_inference_component_name(
+ _get_model_info_from_inference_component_endpoint_without_inference_component_name(
"blahblah", mock_sm_session
)
@@ -236,67 +269,119 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_
)
-@patch("sagemaker.jumpstart.session_utils._get_model_id_version_from_model_based_endpoint")
-def test_get_model_id_version_from_endpoint_non_inference_component_endpoint(
- mock_get_model_id_version_from_model_based_endpoint,
+@patch("sagemaker.jumpstart.session_utils._get_model_info_from_model_based_endpoint")
+def test_get_model_info_from_endpoint_non_inference_component_endpoint(
+ mock_get_model_info_from_model_based_endpoint,
):
mock_sm_session = Mock()
mock_sm_session.is_inference_component_based_endpoint.return_value = False
- mock_get_model_id_version_from_model_based_endpoint.return_value = (
+ mock_get_model_info_from_model_based_endpoint.return_value = (
"model_id",
"model_version",
+ None,
+ None,
)
- retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session)
+ retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session)
- assert retval == ("model_id", "model_version", None)
- mock_get_model_id_version_from_model_based_endpoint.assert_called_once_with(
+ assert retval == ("model_id", "model_version", None, None, None)
+ mock_get_model_info_from_model_based_endpoint.assert_called_once_with(
"blah", None, mock_sm_session
)
mock_sm_session.is_inference_component_based_endpoint.assert_called_once_with("blah")
@patch(
- "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_"
+ "sagemaker.jumpstart.session_utils._get_model_info_from_inference_"
"component_endpoint_with_inference_component_name"
)
-def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_inference_component_name(
- mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name,
+def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_component_name(
+ mock_get_model_info_from_inference_component_endpoint_with_inference_component_name,
):
mock_sm_session = Mock()
mock_sm_session.is_inference_component_based_endpoint.return_value = True
- mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = (
+ mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = (
"model_id",
"model_version",
+ None,
+ None,
)
- retval = get_model_id_version_from_endpoint(
+ retval = get_model_info_from_endpoint(
"blah", inference_component_name="icname", sagemaker_session=mock_sm_session
)
- assert retval == ("model_id", "model_version", "icname")
- mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with(
+ assert retval == ("model_id", "model_version", "icname", None, None)
+ mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with(
"icname", mock_sm_session
)
mock_sm_session.is_inference_component_based_endpoint.assert_not_called()
@patch(
- "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_component_"
+ "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_"
+ "endpoint_without_inference_component_name"
+)
+def test_get_model_info_from_endpoint_inference_component_endpoint_without_inference_component_name(
+ mock_get_model_info_from_inference_component_endpoint_without_inference_component_name,
+):
+ mock_sm_session = Mock()
+ mock_sm_session.is_inference_component_based_endpoint.return_value = True
+ mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = (
+ "model_id",
+ "model_version",
+ None,
+ None,
+ "inferred-icname",
+ )
+
+ retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session)
+
+ assert retval == ("model_id", "model_version", "inferred-icname", None, None)
+ mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once()
+
+
+@patch(
+ "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_"
"endpoint_without_inference_component_name"
)
-def test_get_model_id_version_from_endpoint_inference_component_endpoint_without_inference_component_name(
- mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name,
+def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_config_name(
+ mock_get_model_info_from_inference_component_endpoint_without_inference_component_name,
):
mock_sm_session = Mock()
mock_sm_session.is_inference_component_based_endpoint.return_value = True
- mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.return_value = (
+ mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = (
"model_id",
"model_version",
+ "inference_config_name",
+ None,
+ "inferred-icname",
+ )
+
+ retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session)
+
+ assert retval == ("model_id", "model_version", "inferred-icname", "inference_config_name", None)
+ mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once()
+
+
+@patch(
+ "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_"
+ "endpoint_without_inference_component_name"
+)
+def test_get_model_info_from_endpoint_inference_component_endpoint_with_training_config_name(
+ mock_get_model_info_from_inference_component_endpoint_without_inference_component_name,
+):
+ mock_sm_session = Mock()
+ mock_sm_session.is_inference_component_based_endpoint.return_value = True
+ mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = (
+ "model_id",
+ "model_version",
+ None,
+ "training_config_name",
"inferred-icname",
)
- retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session)
+ retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session)
- assert retval == ("model_id", "model_version", "inferred-icname")
- mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once()
+ assert retval == ("model_id", "model_version", "inferred-icname", None, "training_config_name")
+ mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once()
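
The tests above pin down the tuple contract of get_model_info_from_endpoint. A minimal sketch of that dispatch (not the SDK implementation; the private helper names come from the patch targets above, everything else is assumed):

def get_model_info_from_endpoint_sketch(
    endpoint_name, inference_component_name=None, sagemaker_session=None
):
    # Returns (model_id, model_version, inference_component_name,
    #          inference_config_name, training_config_name).
    if inference_component_name:
        # Explicit IC name: resolve it directly; the endpoint-type check is skipped.
        model_id, version, inf_cfg, train_cfg = (
            _get_model_info_from_inference_component_endpoint_with_inference_component_name(
                inference_component_name, sagemaker_session
            )
        )
    elif sagemaker_session.is_inference_component_based_endpoint(endpoint_name):
        # IC-based endpoint, no name given: infer the single component behind it.
        model_id, version, inf_cfg, train_cfg, inference_component_name = (
            _get_model_info_from_inference_component_endpoint_without_inference_component_name(
                endpoint_name, sagemaker_session
            )
        )
    else:
        model_id, version, inf_cfg, train_cfg = _get_model_info_from_model_based_endpoint(
            endpoint_name, inference_component_name, sagemaker_session
        )
    return model_id, version, inference_component_name, inf_cfg, train_cfg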
diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py
index 3048bbc320..23fa42c09a 100644
--- a/tests/unit/sagemaker/jumpstart/test_types.py
+++ b/tests/unit/sagemaker/jumpstart/test_types.py
@@ -12,15 +12,19 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import copy
+import pytest
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.types import (
JumpStartBenchmarkStat,
JumpStartECRSpecs,
+ JumpStartEnvironmentVariable,
JumpStartHyperparameter,
JumpStartInstanceTypeVariants,
JumpStartModelSpecs,
JumpStartModelHeader,
JumpStartConfigComponent,
+ DeploymentConfigMetadata,
+ JumpStartModelInitKwargs,
)
from tests.unit.sagemaker.jumpstart.constants import (
BASE_SPEC,
@@ -28,6 +32,7 @@
INFERENCE_CONFIGS,
TRAINING_CONFIG_RANKINGS,
TRAINING_CONFIGS,
+ INIT_KWARGS,
)
INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants(
@@ -923,6 +928,7 @@ def test_inference_configs_parsing():
"neuron-inference",
"neuron-budget",
"gpu-inference",
+ "gpu-inference-model-package",
"gpu-inference-budget",
]
@@ -934,9 +940,9 @@ def test_inference_configs_parsing():
assert specs1.incremental_training_supported
assert specs1.hosting_ecr_specs == JumpStartECRSpecs(
{
- "framework": "pytorch",
- "framework_version": "1.5.0",
- "py_version": "py3",
+ "framework": "huggingface-llm-neuronx",
+ "framework_version": "0.0.17",
+ "py_version": "py310",
}
)
assert specs1.training_ecr_specs == JumpStartECRSpecs(
@@ -946,7 +952,10 @@ def test_inference_configs_parsing():
"py_version": "py3",
}
)
- assert specs1.hosting_artifact_key == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz"
+ assert (
+ specs1.hosting_artifact_key
+ == "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/"
+ )
assert specs1.training_artifact_key == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz"
assert (
specs1.hosting_script_key
@@ -1012,6 +1021,80 @@ def test_inference_configs_parsing():
}
),
]
+ assert specs1.inference_environment_variables == [
+ JumpStartEnvironmentVariable(
+ {
+ "name": "SAGEMAKER_PROGRAM",
+ "type": "text",
+ "default": "inference.py",
+ "scope": "container",
+ "required_for_model_class": True,
+ }
+ ),
+ JumpStartEnvironmentVariable(
+ {
+ "name": "SAGEMAKER_SUBMIT_DIRECTORY",
+ "type": "text",
+ "default": "/opt/ml/model/code",
+ "scope": "container",
+ "required_for_model_class": False,
+ }
+ ),
+ JumpStartEnvironmentVariable(
+ {
+ "name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
+ "type": "text",
+ "default": "20",
+ "scope": "container",
+ "required_for_model_class": False,
+ }
+ ),
+ JumpStartEnvironmentVariable(
+ {
+ "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
+ "type": "text",
+ "default": "3600",
+ "scope": "container",
+ "required_for_model_class": False,
+ }
+ ),
+ JumpStartEnvironmentVariable(
+ {
+ "name": "ENDPOINT_SERVER_TIMEOUT",
+ "type": "int",
+ "default": 3600,
+ "scope": "container",
+ "required_for_model_class": True,
+ }
+ ),
+ JumpStartEnvironmentVariable(
+ {
+ "name": "MODEL_CACHE_ROOT",
+ "type": "text",
+ "default": "/opt/ml/model",
+ "scope": "container",
+ "required_for_model_class": True,
+ }
+ ),
+ JumpStartEnvironmentVariable(
+ {
+ "name": "SAGEMAKER_ENV",
+ "type": "text",
+ "default": "1",
+ "scope": "container",
+ "required_for_model_class": True,
+ }
+ ),
+ JumpStartEnvironmentVariable(
+ {
+ "name": "SAGEMAKER_MODEL_SERVER_WORKERS",
+ "type": "int",
+ "default": 1,
+ "scope": "container",
+ "required_for_model_class": True,
+ }
+ ),
+ ]
# Overrided fields in top config
assert specs1.supported_inference_instance_types == ["ml.inf2.xlarge", "ml.inf2.2xlarge"]
@@ -1019,16 +1102,83 @@ def test_inference_configs_parsing():
config = specs1.inference_configs.get_top_config_from_ranking()
assert config.benchmark_metrics == {
- "ml.inf2.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- )
+ "ml.inf2.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ),
+ ]
}
assert len(config.config_components) == 1
- assert config.config_components["neuron-base"] == JumpStartConfigComponent(
- "neuron-base",
- {"supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"]},
+ assert config.config_components["neuron-inference"] == JumpStartConfigComponent(
+ "neuron-inference",
+ {
+ "default_inference_instance_type": "ml.inf2.xlarge",
+ "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"],
+ "hosting_ecr_specs": {
+ "framework": "huggingface-llm-neuronx",
+ "framework_version": "0.0.17",
+ "py_version": "py310",
+ },
+ "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/",
+ "hosting_instance_type_variants": {
+ "regional_aliases": {
+ "us-west-2": {
+ "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
+ "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04"
+ }
+ },
+ "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}},
+ },
+ },
)
- assert list(config.config_components.keys()) == ["neuron-base"]
+ assert list(config.config_components.keys()) == ["neuron-inference"]
+
+ config = specs1.inference_configs.configs["gpu-inference-model-package"]
+ assert config.config_components["gpu-inference-model-package"] == JumpStartConfigComponent(
+ "gpu-inference-model-package",
+ {
+ "default_inference_instance_type": "ml.p2.xlarge",
+ "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"],
+ "hosting_model_package_arns": {
+ "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/"
+ "llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
+ },
+ },
+ )
+ assert config.resolved_config.get("inference_environment_variables") == []
+
+ spec = {
+ **BASE_SPEC,
+ **INFERENCE_CONFIGS,
+ **INFERENCE_CONFIG_RANKINGS,
+ "unrecognized-field": "blah", # New fields in base metadata fields should be ignored
+ }
+ specs1 = JumpStartModelSpecs(spec)
+
+
+def test_set_inference_configs():
+ spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS}
+ specs1 = JumpStartModelSpecs(spec)
+
+ assert list(specs1.inference_config_components.keys()) == [
+ "neuron-base",
+ "neuron-inference",
+ "neuron-budget",
+ "gpu-inference",
+ "gpu-inference-model-package",
+ "gpu-inference-budget",
+ ]
+
+ with pytest.raises(ValueError) as error:
+ specs1.set_config("invalid_name")
+ assert "Cannot find Jumpstart config name invalid_name."
+ "List of config names that is supported by the model: "
+ "['neuron-inference', 'neuron-inference-budget', "
+ "'gpu-inference-budget', 'gpu-inference', 'gpu-inference-model-package']" in str(error.value)
+
+ assert specs1.supported_inference_instance_types == ["ml.inf2.xlarge", "ml.inf2.2xlarge"]
+ specs1.set_config("gpu-inference")
+ assert specs1.supported_inference_instance_types == ["ml.p2.xlarge", "ml.p3.2xlarge"]
def test_training_configs_parsing():
@@ -1133,19 +1283,30 @@ def test_training_configs_parsing():
config = specs1.training_configs.get_top_config_from_ranking()
assert config.benchmark_metrics == {
- "ml.tr1n1.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- ),
- "ml.tr1n1.4xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "50", "unit": "Tokens/S"}
- ),
+ "ml.tr1n1.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ),
+ ],
+ "ml.tr1n1.4xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1}
+ ),
+ ],
}
assert len(config.config_components) == 1
assert config.config_components["neuron-training"] == JumpStartConfigComponent(
"neuron-training",
{
+ "default_training_instance_type": "ml.trn1.2xlarge",
"supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"],
"training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/",
+ "training_ecr_specs": {
+ "framework": "huggingface",
+ "framework_version": "2.0.0",
+ "py_version": "py310",
+ "huggingface_transformers_version": "4.28.1",
+ },
"training_instance_type_variants": {
"regional_aliases": {
"us-west-2": {
@@ -1192,3 +1353,48 @@ def test_set_training_config():
specs1.training_artifact_key
== "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/"
)
+
+ with pytest.raises(ValueError) as error:
+ specs1.set_config("invalid_name", scope=JumpStartScriptScope.TRAINING)
+ assert "Cannot find Jumpstart config name invalid_name."
+ "List of config names that is supported by the model: "
+ "['neuron-training', 'neuron-training-budget', "
+ "'gpu-training-budget', 'gpu-training']" in str(error.value)
+
+ with pytest.raises(ValueError) as error:
+ specs1.set_config("invalid_name", scope="unknown scope")
+
+
+def test_deployment_config_metadata():
+ spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS}
+ specs = JumpStartModelSpecs(spec)
+ jumpstart_config = specs.inference_configs.get_top_config_from_ranking()
+
+ deployment_config_metadata = DeploymentConfigMetadata(
+ jumpstart_config.config_name,
+ jumpstart_config,
+ JumpStartModelInitKwargs(
+ model_id=specs.model_id,
+ model_data=INIT_KWARGS.get("model_data"),
+ image_uri=INIT_KWARGS.get("image_uri"),
+ instance_type=INIT_KWARGS.get("instance_type"),
+ env=INIT_KWARGS.get("env"),
+ config_name=jumpstart_config.config_name,
+ ),
+ )
+
+ json_obj = deployment_config_metadata.to_json()
+
+ assert isinstance(json_obj, dict)
+ assert json_obj["DeploymentConfigName"] == jumpstart_config.config_name
+ for key in json_obj["BenchmarkMetrics"]:
+ assert len(json_obj["BenchmarkMetrics"][key]) == len(
+ jumpstart_config.benchmark_metrics.get(key)
+ )
+ assert json_obj["AccelerationConfigs"] == jumpstart_config.resolved_config.get(
+ "acceleration_configs"
+ )
+ assert json_obj["DeploymentArgs"]["ImageUri"] == INIT_KWARGS.get("image_uri")
+ assert json_obj["DeploymentArgs"]["ModelData"] == INIT_KWARGS.get("model_data")
+ assert json_obj["DeploymentArgs"]["Environment"] == INIT_KWARGS.get("env")
+ assert json_obj["DeploymentArgs"]["InstanceType"] == INIT_KWARGS.get("instance_type")
diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py
index e7a7d522c3..a5a063c696 100644
--- a/tests/unit/sagemaker/jumpstart/test_utils.py
+++ b/tests/unit/sagemaker/jumpstart/test_utils.py
@@ -13,6 +13,8 @@
from __future__ import absolute_import
import os
from unittest import TestCase
+
+from botocore.exceptions import ClientError
from mock.mock import Mock, patch
import pytest
import boto3
@@ -49,6 +51,8 @@
get_spec_from_base_spec,
get_special_model_spec,
get_prototype_manifest,
+ get_base_deployment_configs_metadata,
+ get_base_deployment_configs,
)
from mock import MagicMock
@@ -207,16 +211,16 @@ def test_is_jumpstart_model_uri():
assert utils.is_jumpstart_model_uri(random_jumpstart_s3_uri("random_key"))
-def test_add_jumpstart_model_id_version_tags():
+def test_add_jumpstart_model_info_tags():
tags = None
model_id = "model_id"
version = "version"
+ inference_config_name = "inference_config_name"
+ training_config_name = "training_config_name"
assert [
{"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"},
{"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"},
- ] == utils.add_jumpstart_model_id_version_tags(
- tags=tags, model_id=model_id, model_version=version
- )
+ ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version)
tags = [
{"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"},
@@ -228,9 +232,7 @@ def test_add_jumpstart_model_id_version_tags():
assert [
{"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"},
{"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version_2"},
- ] == utils.add_jumpstart_model_id_version_tags(
- tags=tags, model_id=model_id, model_version=version
- )
+ ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version)
tags = [
{"Key": "random key", "Value": "random_value"},
@@ -241,9 +243,7 @@ def test_add_jumpstart_model_id_version_tags():
{"Key": "random key", "Value": "random_value"},
{"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"},
{"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"},
- ] == utils.add_jumpstart_model_id_version_tags(
- tags=tags, model_id=model_id, model_version=version
- )
+ ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version)
tags = [
{"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"},
@@ -254,9 +254,7 @@ def test_add_jumpstart_model_id_version_tags():
assert [
{"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"},
{"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"},
- ] == utils.add_jumpstart_model_id_version_tags(
- tags=tags, model_id=model_id, model_version=version
- )
+ ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version)
tags = [
{"Key": "random key", "Value": "random_value"},
@@ -265,8 +263,58 @@ def test_add_jumpstart_model_id_version_tags():
version = None
assert [
{"Key": "random key", "Value": "random_value"},
- ] == utils.add_jumpstart_model_id_version_tags(
- tags=tags, model_id=model_id, model_version=version
+ ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version)
+
+ tags = [
+ {"Key": "random key", "Value": "random_value"},
+ ]
+ model_id = "model_id"
+ version = "version"
+ assert [
+ {"Key": "random key", "Value": "random_value"},
+ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"},
+ {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"},
+ {"Key": "sagemaker-sdk:jumpstart-inference-config-name", "Value": "inference_config_name"},
+ ] == utils.add_jumpstart_model_info_tags(
+ tags=tags,
+ model_id=model_id,
+ model_version=version,
+ config_name=inference_config_name,
+ scope=JumpStartScriptScope.INFERENCE,
+ )
+
+ tags = [
+ {"Key": "random key", "Value": "random_value"},
+ ]
+ model_id = "model_id"
+ version = "version"
+ assert [
+ {"Key": "random key", "Value": "random_value"},
+ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"},
+ {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"},
+ {"Key": "sagemaker-sdk:jumpstart-training-config-name", "Value": "training_config_name"},
+ ] == utils.add_jumpstart_model_info_tags(
+ tags=tags,
+ model_id=model_id,
+ model_version=version,
+ config_name=training_config_name,
+ scope=JumpStartScriptScope.TRAINING,
+ )
+
+ tags = [
+ {"Key": "random key", "Value": "random_value"},
+ ]
+ model_id = "model_id"
+ version = "version"
+ assert [
+ {"Key": "random key", "Value": "random_value"},
+ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"},
+ {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"},
+ ] == utils.add_jumpstart_model_info_tags(
+ tags=tags,
+ model_id=model_id,
+ model_version=version,
+ config_name=training_config_name,
)
@@ -1319,10 +1367,8 @@ def test_no_model_id_no_version_found(self):
mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}]
self.assertEquals(
- utils.get_jumpstart_model_id_version_from_resource_arn(
- "some-arn", mock_sagemaker_session
- ),
- (None, None),
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ (None, None, None, None),
)
mock_list_tags.assert_called_once_with("some-arn")
@@ -1336,10 +1382,8 @@ def test_model_id_no_version_found(self):
]
self.assertEquals(
- utils.get_jumpstart_model_id_version_from_resource_arn(
- "some-arn", mock_sagemaker_session
- ),
- ("model_id", None),
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ ("model_id", None, None, None),
)
mock_list_tags.assert_called_once_with("some-arn")
@@ -1353,10 +1397,66 @@ def test_no_model_id_version_found(self):
]
self.assertEquals(
- utils.get_jumpstart_model_id_version_from_resource_arn(
- "some-arn", mock_sagemaker_session
- ),
- (None, "model_version"),
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ (None, "model_version", None, None),
+ )
+ mock_list_tags.assert_called_once_with("some-arn")
+
+ def test_no_config_name_found(self):
+ mock_list_tags = Mock()
+ mock_sagemaker_session = Mock()
+ mock_sagemaker_session.list_tags = mock_list_tags
+ mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}]
+
+ self.assertEquals(
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ (None, None, None, None),
+ )
+ mock_list_tags.assert_called_once_with("some-arn")
+
+ def test_inference_config_name_found(self):
+ mock_list_tags = Mock()
+ mock_sagemaker_session = Mock()
+ mock_sagemaker_session.list_tags = mock_list_tags
+ mock_list_tags.return_value = [
+ {"Key": "blah", "Value": "blah1"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name"},
+ ]
+
+ self.assertEquals(
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ (None, None, "config_name", None),
+ )
+ mock_list_tags.assert_called_once_with("some-arn")
+
+ def test_training_config_name_found(self):
+ mock_list_tags = Mock()
+ mock_sagemaker_session = Mock()
+ mock_sagemaker_session.list_tags = mock_list_tags
+ mock_list_tags.return_value = [
+ {"Key": "blah", "Value": "blah1"},
+ {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "config_name"},
+ ]
+
+ self.assertEquals(
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ (None, None, None, "config_name"),
+ )
+ mock_list_tags.assert_called_once_with("some-arn")
+
+ def test_both_config_name_found(self):
+ mock_list_tags = Mock()
+ mock_sagemaker_session = Mock()
+ mock_sagemaker_session.list_tags = mock_list_tags
+ mock_list_tags.return_value = [
+ {"Key": "blah", "Value": "blah1"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "inference_config_name"},
+ {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "training_config_name"},
+ ]
+
+ self.assertEquals(
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ (None, None, "inference_config_name", "training_config_name"),
)
mock_list_tags.assert_called_once_with("some-arn")
@@ -1371,10 +1471,8 @@ def test_model_id_version_found(self):
]
self.assertEquals(
- utils.get_jumpstart_model_id_version_from_resource_arn(
- "some-arn", mock_sagemaker_session
- ),
- ("model_id", "model_version"),
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ ("model_id", "model_version", None, None),
)
mock_list_tags.assert_called_once_with("some-arn")
@@ -1391,10 +1489,8 @@ def test_multiple_model_id_versions_found(self):
]
self.assertEquals(
- utils.get_jumpstart_model_id_version_from_resource_arn(
- "some-arn", mock_sagemaker_session
- ),
- (None, None),
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ (None, None, None, None),
)
mock_list_tags.assert_called_once_with("some-arn")
@@ -1411,10 +1507,8 @@ def test_multiple_model_id_versions_found_aliases_consistent(self):
]
self.assertEquals(
- utils.get_jumpstart_model_id_version_from_resource_arn(
- "some-arn", mock_sagemaker_session
- ),
- ("model_id_1", "model_version_1"),
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ ("model_id_1", "model_version_1", None, None),
)
mock_list_tags.assert_called_once_with("some-arn")
@@ -1431,10 +1525,26 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self):
]
self.assertEquals(
- utils.get_jumpstart_model_id_version_from_resource_arn(
- "some-arn", mock_sagemaker_session
- ),
- (None, None),
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ (None, None, None, None),
+ )
+ mock_list_tags.assert_called_once_with("some-arn")
+
+ def test_multiple_config_names_found_aliases_inconsistent(self):
+ mock_list_tags = Mock()
+ mock_sagemaker_session = Mock()
+ mock_sagemaker_session.list_tags = mock_list_tags
+ mock_list_tags.return_value = [
+ {"Key": "blah", "Value": "blah1"},
+ {"Key": JumpStartTag.MODEL_ID, "Value": "model_id_1"},
+ {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_1"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_1"},
+ {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_2"},
+ ]
+
+ self.assertEquals(
+ utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session),
+ ("model_id_1", "model_version_1", None, None),
)
mock_list_tags.assert_called_once_with("some-arn")
@@ -1529,6 +1639,7 @@ def test_get_jumpstart_config_names_success(
"neuron-inference-budget",
"gpu-inference-budget",
"gpu-inference",
+ "gpu-inference-model-package",
]
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -1598,24 +1709,39 @@ def test_get_jumpstart_benchmark_stats_full_list(
"mock-region", "mock-model", "mock-model-version", config_names=None
) == {
"neuron-inference": {
- "ml.inf2.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- )
+ "ml.inf2.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ]
},
"neuron-inference-budget": {
- "ml.inf2.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- )
+ "ml.inf2.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ]
},
"gpu-inference-budget": {
- "ml.p3.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- )
+ "ml.p3.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ]
},
"gpu-inference": {
- "ml.p3.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- )
+ "ml.p3.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ]
+ },
+ "gpu-inference-model-package": {
+ "ml.p3.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ]
},
}
@@ -1633,14 +1759,18 @@ def test_get_jumpstart_benchmark_stats_partial_list(
config_names=["neuron-inference-budget", "gpu-inference-budget"],
) == {
"neuron-inference-budget": {
- "ml.inf2.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- )
+ "ml.inf2.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ]
},
"gpu-inference-budget": {
- "ml.p3.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- )
+ "ml.p3.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ]
},
}
@@ -1658,9 +1788,11 @@ def test_get_jumpstart_benchmark_stats_single_stat(
config_names=["neuron-inference-budget"],
) == {
"neuron-inference-budget": {
- "ml.inf2.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- )
+ "ml.inf2.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ]
}
}
@@ -1687,6 +1819,16 @@ def test_get_jumpstart_benchmark_stats_training(
):
patched_get_model_specs.side_effect = get_base_spec_with_prototype_configs
assert utils.get_benchmark_stats(
"mock-region",
"mock-model",
@@ -1695,16 +1837,202 @@ def test_get_jumpstart_benchmark_stats_training(
config_names=["neuron-training", "gpu-training-budget"],
) == {
"neuron-training": {
- "ml.tr1n1.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- ),
- "ml.tr1n1.4xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "50", "unit": "Tokens/S"}
- ),
+ "ml.tr1n1.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ],
+ "ml.tr1n1.4xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ],
},
"gpu-training-budget": {
- "ml.p3.2xlarge": JumpStartBenchmarkStat(
- {"name": "Latency", "value": "100", "unit": "Tokens/S"}
- )
+ "ml.p3.2xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": "1"}
+ )
+ ]
},
}
+
+
+def test_extract_metrics_from_deployment_configs():
+ configs = get_base_deployment_configs_metadata()
+ configs[0].benchmark_metrics = None
+ configs[2].deployment_args = None
+
+ data = utils.get_metrics_from_deployment_configs(configs)
+
+ for key in data:
+ assert len(data[key]) == (len(configs) - 2)
+
+
+@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour")
+def test_add_instance_rate_stats_to_benchmark_metrics(
+ mock_get_instance_rate_per_hour,
+):
+ mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: {
+ "name": "Instance Rate",
+ "unit": "USD/Hrs",
+ "value": "3.76",
+ }
+
+ err, out = utils.add_instance_rate_stats_to_benchmark_metrics(
+ "us-west-2",
+ {
+ "ml.p2.xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ],
+ "ml.gd4.xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ],
+ },
+ )
+
+ assert err is None
+ for key in out:
+ assert len(out[key]) == 2
+ for metric in out[key]:
+ if metric.name == "Instance Rate":
+ assert metric.to_json() == {
+ "name": "Instance Rate",
+ "unit": "USD/Hrs",
+ "value": "3.76",
+ "concurrency": None,
+ }
+
+
+def test__normalize_benchmark_metrics():
+ rate, metrics = utils._normalize_benchmark_metrics(
+ [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ),
+ JumpStartBenchmarkStat(
+ {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ),
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 2}
+ ),
+ JumpStartBenchmarkStat(
+ {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 2}
+ ),
+ JumpStartBenchmarkStat(
+ {"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76", "concurrency": None}
+ ),
+ ]
+ )
+
+ assert rate == JumpStartBenchmarkStat(
+ {"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76", "concurrency": None}
+ )
+ assert metrics == {
+ 1: [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ),
+ JumpStartBenchmarkStat(
+ {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ ),
+ ],
+ 2: [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 2}
+ ),
+ JumpStartBenchmarkStat(
+ {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 2}
+ ),
+ ],
+ }
+
+
+@pytest.mark.parametrize(
+ "name, expected",
+ [
+ ("latency", "Latency for each user (TTFT in ms)"),
+ ("throughput", "Throughput per user (token/seconds)"),
+ ],
+)
+def test__normalize_benchmark_metric_column_name(name, expected):
+ out = utils._normalize_benchmark_metric_column_name(name)
+
+ assert out == expected
+
+
+@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour")
+def test_add_instance_rate_stats_to_benchmark_metrics_client_ex(
+ mock_get_instance_rate_per_hour,
+):
+ mock_get_instance_rate_per_hour.side_effect = ClientError(
+ {
+ "Error": {
+ "Message": "is not authorized to perform: pricing:GetProducts",
+ "Code": "AccessDenied",
+ },
+ },
+ "GetProducts",
+ )
+
+ err, out = utils.add_instance_rate_stats_to_benchmark_metrics(
+ "us-west-2",
+ {
+ "ml.p2.xlarge": [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1}
+ )
+ ],
+ },
+ )
+
+ assert err["Message"] == "is not authorized to perform: pricing:GetProducts"
+ assert err["Code"] == "AccessDenied"
+ for key in out:
+ assert len(out[key]) == 1
+
+
+@pytest.mark.parametrize(
+ "stats, expected",
+ [
+ (None, True),
+ (
+ [
+ JumpStartBenchmarkStat(
+ {
+ "name": "Instance Rate",
+ "unit": "USD/Hrs",
+ "value": "3.76",
+ "concurrency": None,
+ }
+ )
+ ],
+ True,
+ ),
+ (
+ [
+ JumpStartBenchmarkStat(
+ {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": None}
+ )
+ ],
+ False,
+ ),
+ ],
+)
+def test_has_instance_rate_stat(stats, expected):
+ assert utils.has_instance_rate_stat(stats) is expected
+
+
+@pytest.mark.parametrize(
+ "data, expected",
+ [(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())],
+)
+def test_deployment_config_response_data(data, expected):
+ out = utils.deployment_config_response_data(data)
+
+ assert out == expected
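
The grouping contract behind test__normalize_benchmark_metrics, restated as a runnable sketch with plain dicts standing in for JumpStartBenchmarkStat:

def normalize_benchmark_metrics_sketch(stats):
    # Split out the single "Instance Rate" stat; bucket everything else by
    # its "concurrency" value, preserving order within each bucket.
    instance_rate = None
    by_concurrency = {}
    for stat in stats:
        if stat["name"] == "Instance Rate":
            instance_rate = stat
        else:
            by_concurrency.setdefault(stat["concurrency"], []).append(stat)
    return instance_rate, by_concurrency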
diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py
index e102251060..cc4ef71cee 100644
--- a/tests/unit/sagemaker/jumpstart/utils.py
+++ b/tests/unit/sagemaker/jumpstart/utils.py
@@ -12,9 +12,10 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import copy
-from typing import List
+from typing import List, Dict, Any, Optional
import boto3
+from sagemaker.compute_resource_requirements import ResourceRequirements
from sagemaker.jumpstart.cache import JumpStartModelsCache
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_REGION_NAME,
@@ -27,6 +28,10 @@
JumpStartModelSpecs,
JumpStartS3FileType,
JumpStartModelHeader,
+ JumpStartModelInitKwargs,
+ DeploymentConfigMetadata,
+ JumpStartModelDeployKwargs,
+ JumpStartBenchmarkStat,
)
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.utils import get_formatted_manifest
@@ -43,6 +48,8 @@
SPECIAL_MODEL_SPECS_DICT,
TRAINING_CONFIG_RANKINGS,
TRAINING_CONFIGS,
+ DEPLOYMENT_CONFIGS,
+ INIT_KWARGS,
)
@@ -222,6 +229,43 @@ def get_base_spec_with_prototype_configs(
return JumpStartModelSpecs(spec)
+def get_base_spec_with_prototype_configs_with_missing_benchmarks(
+ region: str = None,
+ model_id: str = None,
+ version: str = None,
+ s3_client: boto3.client = None,
+ model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+) -> JumpStartModelSpecs:
+ spec = copy.deepcopy(BASE_SPEC)
+ copy_inference_configs = copy.deepcopy(INFERENCE_CONFIGS)
+ copy_inference_configs["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None
+
+ inference_configs = {**copy_inference_configs, **INFERENCE_CONFIG_RANKINGS}
+ training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS}
+
+ spec.update(inference_configs)
+ spec.update(training_configs)
+
+ return JumpStartModelSpecs(spec)
+
+
+def get_prototype_spec_with_configs(
+ region: str = None,
+ model_id: str = None,
+ version: str = None,
+ s3_client: boto3.client = None,
+ model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
+) -> JumpStartModelSpecs:
+ spec = copy.deepcopy(PROTOTYPICAL_MODEL_SPECS_DICT[model_id])
+ inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS}
+ training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS}
+
+ spec.update(inference_configs)
+ spec.update(training_configs)
+
+ return JumpStartModelSpecs(spec)
+
+
def patched_retrieval_function(
_modelCacheObj: JumpStartModelsCache,
key: JumpStartCachedS3ContentKey,
@@ -280,3 +324,101 @@ def overwrite_dictionary(
base_dictionary[key] = value
return base_dictionary
+
+
+def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, Any]]:
+ configs = copy.deepcopy(DEPLOYMENT_CONFIGS)
+ configs[0]["AccelerationConfigs"] = [
+ {"Type": "Speculative-Decoding", "Enabled": True, "Spec": {"Version": "0.1"}}
+ ]
+ return configs
+
+
+def get_mock_init_kwargs(
+ model_id: str, config_name: Optional[str] = None
+) -> JumpStartModelInitKwargs:
+ return JumpStartModelInitKwargs(
+ model_id=model_id,
+ model_type=JumpStartModelType.OPEN_WEIGHTS,
+ model_data=INIT_KWARGS.get("model_data"),
+ image_uri=INIT_KWARGS.get("image_uri"),
+ instance_type=INIT_KWARGS.get("instance_type"),
+ env=INIT_KWARGS.get("env"),
+ resources=ResourceRequirements(),
+ config_name=config_name,
+ )
+
+
+def get_base_deployment_configs_metadata(
+ omit_benchmark_metrics: bool = False,
+) -> List[DeploymentConfigMetadata]:
+ specs = (
+ get_base_spec_with_prototype_configs_with_missing_benchmarks()
+ if omit_benchmark_metrics
+ else get_base_spec_with_prototype_configs()
+ )
+ configs = []
+ for config_name in specs.inference_configs.config_rankings.get("overall").rankings:
+ jumpstart_config = specs.inference_configs.configs.get(config_name)
+ benchmark_metrics = jumpstart_config.benchmark_metrics
+
+ if benchmark_metrics:
+ for instance_type in benchmark_metrics:
+ benchmark_metrics[instance_type].append(
+ JumpStartBenchmarkStat(
+ {
+ "name": "Instance Rate",
+ "unit": "USD/Hrs",
+ "value": "3.76",
+ "concurrency": None,
+ }
+ )
+ )
+
+ configs.append(
+ DeploymentConfigMetadata(
+ config_name=config_name,
+ metadata_config=jumpstart_config,
+ init_kwargs=get_mock_init_kwargs(
+ get_base_spec_with_prototype_configs().model_id, config_name
+ ),
+ deploy_kwargs=JumpStartModelDeployKwargs(
+ model_id=get_base_spec_with_prototype_configs().model_id,
+ ),
+ )
+ )
+ return configs
+
+
+def get_base_deployment_configs(
+ omit_benchmark_metrics: bool = False,
+) -> List[Dict[str, Any]]:
+ configs = []
+ for config in get_base_deployment_configs_metadata(omit_benchmark_metrics):
+ config_json = config.to_json()
+ if config_json["BenchmarkMetrics"]:
+ config_json["BenchmarkMetrics"] = {
+ config.deployment_args.instance_type: config_json["BenchmarkMetrics"].get(
+ config.deployment_args.instance_type
+ )
+ }
+ configs.append(config_json)
+ return configs
+
+
+def append_instance_stat_metrics(
+ metrics: Dict[str, List[JumpStartBenchmarkStat]]
+) -> Dict[str, List[JumpStartBenchmarkStat]]:
+ if metrics is not None:
+ for key in metrics:
+ metrics[key].append(
+ JumpStartBenchmarkStat(
+ {
+ "name": "Instance Rate",
+ "value": "3.76",
+ "unit": "USD/Hrs",
+ "concurrency": None,
+ }
+ )
+ )
+ return metrics
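
Taken together, these fixtures feed the response-shaping path: metadata objects from get_base_deployment_configs_metadata() are serialized by utils.deployment_config_response_data into plain dicts, which is what get_base_deployment_configs() precomputes. A usage sketch mirroring test_deployment_config_response_data:

from sagemaker.jumpstart import utils

metadata = get_base_deployment_configs_metadata()
response = utils.deployment_config_response_data(metadata)  # list of JSON-style dicts
assert response == get_base_deployment_configs()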
diff --git a/tests/unit/sagemaker/local/test_local_image.py b/tests/unit/sagemaker/local/test_local_image.py
index 08c55fa0b4..3142fa6dfa 100644
--- a/tests/unit/sagemaker/local/test_local_image.py
+++ b/tests/unit/sagemaker/local/test_local_image.py
@@ -221,9 +221,12 @@ def test_write_config_file(LocalSession, tmpdir):
assert os.path.exists(resource_config_file)
assert os.path.exists(input_data_config_file)
- hyperparameters_data = json.load(open(hyperparameters_file))
- resource_config_data = json.load(open(resource_config_file))
- input_data_config_data = json.load(open(input_data_config_file))
+ with open(hyperparameters_file) as f:
+ hyperparameters_data = json.load(f)
+ with open(resource_config_file) as f:
+ resource_config_data = json.load(f)
+ with open(input_data_config_file) as f:
+ input_data_config_data = json.load(f)
# Validate HyperParameters
for k, v in HYPERPARAMETERS.items():
@@ -280,7 +283,8 @@ def test_write_config_files_input_content_type(LocalSession, tmpdir):
sagemaker_container.write_config_files(host, HYPERPARAMETERS, input_data_config)
assert os.path.exists(input_data_config_file)
- parsed_input_config = json.load(open(input_data_config_file))
+ with open(input_data_config_file) as f:
+ parsed_input_config = json.load(f)
# Validate Input Data Config
for channel in input_data_config:
assert channel["ChannelName"] in parsed_input_config
diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py
index 953cbe775c..69ea2c1f56 100644
--- a/tests/unit/sagemaker/model/test_deploy.py
+++ b/tests/unit/sagemaker/model/test_deploy.py
@@ -125,6 +125,7 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem
volume_size=None,
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
+ routing_config=None,
)
sagemaker_session.create_model.assert_called_with(
@@ -184,6 +185,7 @@ def test_deploy_accelerator_type(
volume_size=None,
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
+ routing_config=None,
)
sagemaker_session.endpoint_from_production_variants.assert_called_with(
@@ -506,6 +508,7 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model,
volume_size=None,
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
+ routing_config=None,
)
sagemaker_session.endpoint_from_production_variants.assert_called_with(
@@ -938,6 +941,7 @@ def test_deploy_customized_volume_size_and_timeout(
volume_size=volume_size_gb,
model_data_download_timeout=model_data_download_timeout_sec,
container_startup_health_check_timeout=startup_health_check_timeout_sec,
+ routing_config=None,
)
sagemaker_session.create_model.assert_called_with(
@@ -987,6 +991,7 @@ def test_deploy_with_resources(sagemaker_session, name_from_base, production_var
volume_size=None,
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
+ routing_config=None,
)
sagemaker_session.endpoint_from_production_variants.assert_called_with(
name=name_from_base(MODEL_NAME),
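
The routing_config=None threaded through these assertions corresponds to the new pass-through on Model.deploy. A hedged usage sketch -- the strategy values follow the CreateEndpointConfig ProductionVariants API, and the model object is assumed already constructed:

predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
    routing_config={"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},  # or "RANDOM"
)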
diff --git a/tests/unit/sagemaker/serializers/test_serializers.py b/tests/unit/sagemaker/serializers/test_serializers.py
index 6b70c600ca..cb69e9e4ac 100644
--- a/tests/unit/sagemaker/serializers/test_serializers.py
+++ b/tests/unit/sagemaker/serializers/test_serializers.py
@@ -345,7 +345,8 @@ def test_data_serializer_raw(data_serializer):
input_image = image.read()
input_image_data = data_serializer.serialize(input_image)
validation_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.raw")
- validation_image_data = open(validation_image_file_path, "rb").read()
+ with open(validation_image_file_path, "rb") as f:
+ validation_image_data = f.read()
assert input_image_data == validation_image_data
@@ -353,5 +354,6 @@ def test_data_serializer_file_like(data_serializer):
input_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.jpg")
validation_image_file_path = os.path.join(DATA_DIR, "", "cuteCat.raw")
input_image_data = data_serializer.serialize(input_image_file_path)
- validation_image_data = open(validation_image_file_path, "rb").read()
+ with open(validation_image_file_path, "rb") as f:
+ validation_image_data = f.read()
assert input_image_data == validation_image_data
diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py
index 2a0c791215..e38317067c 100644
--- a/tests/unit/sagemaker/serve/builder/test_js_builder.py
+++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py
@@ -23,6 +23,7 @@
LocalModelOutOfMemoryException,
LocalModelInvocationException,
)
+from tests.unit.sagemaker.serve.constants import DEPLOYMENT_CONFIGS
mock_model_id = "huggingface-llm-amazon-falconlite"
mock_t5_model_id = "google/flan-t5-xxl"
@@ -63,6 +64,10 @@
"123456789712.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi"
"-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04"
)
+mock_invalid_image_uri = (
+ "123456789712.dkr.ecr.us-west-2.amazonaws.com/invalid"
+ "-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04"
+)
mock_djl_image_uri = (
"123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1"
)
@@ -82,6 +87,88 @@
class TestJumpStartBuilder(unittest.TestCase):
+ @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+ return_value=True,
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+ return_value=MagicMock(),
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
+ return_value=({"model_type": "t5", "n_head": 71}, True),
+ )
+ @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
+ )
+ def test__build_for_jumpstart_value_error(
+ self,
+ mock_get_nb_instance,
+ mock_get_ram_usage_mb,
+ mock_prepare_for_tgi,
+ mock_pre_trained_model,
+ mock_is_jumpstart_model,
+ mock_telemetry,
+ ):
+ builder = ModelBuilder(
+ model="facebook/invalid",
+ schema_builder=mock_schema_builder,
+ mode=Mode.LOCAL_CONTAINER,
+ )
+
+ mock_pre_trained_model.return_value.image_uri = mock_invalid_image_uri
+
+ self.assertRaises(
+ ValueError,
+ lambda: builder.build(),
+ )
+
+ @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+ return_value=True,
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+ return_value=MagicMock(),
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.prepare_mms_js_resources",
+ return_value=({"model_type": "t5", "n_head": 71}, True),
+ )
+ @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
+ )
+ def test__build_for_mms_jumpstart(
+ self,
+ mock_get_nb_instance,
+ mock_get_ram_usage_mb,
+ mock_prepare_for_mms,
+ mock_pre_trained_model,
+ mock_is_jumpstart_model,
+ mock_telemetry,
+ ):
+ builder = ModelBuilder(
+ model="facebook/galactica-mock-model-id",
+ schema_builder=mock_schema_builder,
+ mode=Mode.LOCAL_CONTAINER,
+ )
+
+ mock_pre_trained_model.return_value.image_uri = (
+ "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface"
+ "-pytorch-inference:2.1.0-transformers4.37.0-gpu-py310-cu118"
+ "-ubuntu20.04"
+ )
+
+ builder.build()
+ builder.serve_settings.telemetry_opt_out = True
+
+ mock_prepare_for_mms.assert_called()
+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
"sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
@@ -638,3 +725,239 @@ def test_js_gated_model_ex(
ValueError,
lambda: builder.build(),
)
+
+ @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+ return_value=True,
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+ return_value=MagicMock(),
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
+ return_value=({"model_type": "t5", "n_head": 71}, True),
+ )
+ @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
+ )
+ def test_list_deployment_configs(
+ self,
+ mock_get_nb_instance,
+ mock_get_ram_usage_mb,
+ mock_prepare_for_tgi,
+ mock_pre_trained_model,
+ mock_is_jumpstart_model,
+ mock_telemetry,
+ ):
+ builder = ModelBuilder(
+ model="facebook/galactica-mock-model-id",
+ schema_builder=mock_schema_builder,
+ )
+
+ mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
+ mock_pre_trained_model.return_value.list_deployment_configs.side_effect = (
+ lambda: DEPLOYMENT_CONFIGS
+ )
+
+ configs = builder.list_deployment_configs()
+
+ self.assertEqual(configs, DEPLOYMENT_CONFIGS)
+
+ @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+ return_value=True,
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+ return_value=MagicMock(),
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
+ return_value=({"model_type": "t5", "n_head": 71}, True),
+ )
+ @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
+ )
+ def test_get_deployment_config(
+ self,
+ mock_get_nb_instance,
+ mock_get_ram_usage_mb,
+ mock_prepare_for_tgi,
+ mock_pre_trained_model,
+ mock_is_jumpstart_model,
+ mock_telemetry,
+ ):
+ builder = ModelBuilder(
+ model="facebook/galactica-mock-model-id",
+ schema_builder=mock_schema_builder,
+ )
+
+ mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
+
+ expected = DEPLOYMENT_CONFIGS[0]
+ mock_pre_trained_model.return_value.deployment_config = expected
+
+ self.assertEqual(builder.get_deployment_config(), expected)
+
+ @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+ return_value=True,
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+ return_value=MagicMock(),
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
+ return_value=({"model_type": "t5", "n_head": 71}, True),
+ )
+ @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
+ )
+ def test_set_deployment_config(
+ self,
+ mock_get_nb_instance,
+ mock_get_ram_usage_mb,
+ mock_prepare_for_tgi,
+ mock_pre_trained_model,
+ mock_is_jumpstart_model,
+ mock_telemetry,
+ ):
+ builder = ModelBuilder(
+ model="facebook/galactica-mock-model-id",
+ schema_builder=mock_schema_builder,
+ )
+
+ mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
+
+ builder.build()
+ builder.set_deployment_config("config-1", "ml.g5.24xlarge")
+
+ mock_pre_trained_model.return_value.set_deployment_config.assert_called_with(
+ "config-1", "ml.g5.24xlarge"
+ )
+
+ @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+ return_value=True,
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+ return_value=MagicMock(),
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
+ return_value=({"model_type": "t5", "n_head": 71}, True),
+ )
+ @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
+ )
+ def test_set_deployment_config_ex(
+ self,
+ mock_get_nb_instance,
+ mock_get_ram_usage_mb,
+ mock_prepare_for_tgi,
+ mock_pre_trained_model,
+ mock_is_jumpstart_model,
+ mock_telemetry,
+ ):
+ mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
+
+ self.assertRaisesRegex(
+ Exception,
+ "Cannot set deployment config to an uninitialized model.",
+ lambda: ModelBuilder(
+ model="facebook/galactica-mock-model-id", schema_builder=mock_schema_builder
+ ).set_deployment_config("config-2", "ml.g5.24xlarge"),
+ )
+
+ @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+ return_value=True,
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+ return_value=MagicMock(),
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
+ return_value=({"model_type": "t5", "n_head": 71}, True),
+ )
+ @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
+ )
+ def test_display_benchmark_metrics(
+ self,
+ mock_get_nb_instance,
+ mock_get_ram_usage_mb,
+ mock_prepare_for_tgi,
+ mock_pre_trained_model,
+ mock_is_jumpstart_model,
+ mock_telemetry,
+ ):
+ builder = ModelBuilder(
+ model="facebook/galactica-mock-model-id",
+ schema_builder=mock_schema_builder,
+ )
+
+ mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
+ mock_pre_trained_model.return_value.list_deployment_configs.side_effect = (
+ lambda: DEPLOYMENT_CONFIGS
+ )
+
+ builder.list_deployment_configs()
+
+ builder.display_benchmark_metrics()
+
+ mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once()
+
+ @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+ return_value=True,
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+ return_value=MagicMock(),
+ )
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
+ return_value=({"model_type": "t5", "n_head": 71}, True),
+ )
+ @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
+ @patch(
+ "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
+ )
+ def test_display_benchmark_metrics_initial(
+ self,
+ mock_get_nb_instance,
+ mock_get_ram_usage_mb,
+ mock_prepare_for_tgi,
+ mock_pre_trained_model,
+ mock_is_jumpstart_model,
+ mock_telemetry,
+ ):
+ builder = ModelBuilder(
+ model="facebook/galactica-mock-model-id",
+ schema_builder=mock_schema_builder,
+ )
+
+ mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
+ mock_pre_trained_model.return_value.list_deployment_configs.side_effect = (
+ lambda: DEPLOYMENT_CONFIGS
+ )
+
+ builder.display_benchmark_metrics()
+
+ mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once()
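
End to end, the ModelBuilder hooks these tests cover compose as below -- a usage sketch with a placeholder config name, reusing the mock model id from this file; each call delegates to the underlying pre-trained JumpStart model:

builder = ModelBuilder(
    model="huggingface-llm-amazon-falconlite", schema_builder=schema_builder
)

configs = builder.list_deployment_configs()  # fetch the selectable configs
builder.display_benchmark_metrics()          # tabulate latency/throughput/instance rate
builder.build()                              # required before set_deployment_config
builder.set_deployment_config("config-1", "ml.g5.24xlarge")
current = builder.get_deployment_config()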
diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py
index 1d199b7401..0c06b5ae8e 100644
--- a/tests/unit/sagemaker/serve/builder/test_model_builder.py
+++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -21,6 +21,7 @@
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils import task
from sagemaker.serve.utils.exceptions import TaskNotFoundException
+from sagemaker.serve.utils.predictors import TensorflowServingLocalPredictor
from sagemaker.serve.utils.types import ModelServer
from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG
@@ -52,6 +53,7 @@
ModelServer.TORCHSERVE,
ModelServer.TRITON,
ModelServer.DJL_SERVING,
+ ModelServer.TENSORFLOW_SERVING,
}
mock_session = MagicMock()
@@ -1474,6 +1476,44 @@ def test_text_generation(
mock_build_for_tgi.assert_called_once()
+ @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei")
+ @patch("sagemaker.image_uris.retrieve")
+ @patch("sagemaker.djl_inference.model.urllib")
+ @patch("sagemaker.djl_inference.model.json")
+ @patch("sagemaker.huggingface.llm_utils.urllib")
+ @patch("sagemaker.huggingface.llm_utils.json")
+ @patch("sagemaker.model_uris.retrieve")
+ @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+ def test_sentence_similarity(
+ self,
+ mock_serveSettings,
+ mock_model_uris_retrieve,
+ mock_llm_utils_json,
+ mock_llm_utils_urllib,
+ mock_model_json,
+ mock_model_urllib,
+ mock_image_uris_retrieve,
+ mock_build_for_tei,
+ ):
+ mock_setting_object = mock_serveSettings.return_value
+ mock_setting_object.role_arn = mock_role_arn
+ mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+ mock_model_uris_retrieve.side_effect = KeyError
+ mock_llm_utils_json.load.return_value = {"pipeline_tag": "sentence-similarity"}
+ mock_llm_utils_urllib.request.Request.side_effect = Mock()
+
+ mock_model_json.load.return_value = {"some": "config"}
+ mock_model_urllib.request.Request.side_effect = Mock()
+ mock_build_for_tei.side_effect = Mock()
+
+ mock_image_uris_retrieve.return_value = "https://some-image-uri"
+
+ model_builder = ModelBuilder(model="bloom-560m", schema_builder=schema_builder)
+ model_builder.build(sagemaker_session=mock_session)
+
+ mock_build_for_tei.assert_called_once()
+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock())
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._try_fetch_gpu_info")
@patch("sagemaker.image_uris.retrieve")
@@ -1677,6 +1717,7 @@ def test_build_task_override_with_invalid_model_provided(
model_builder.build(sagemaker_session=mock_session)
@patch("os.makedirs", Mock())
+ @patch("sagemaker.serve.builder.model_builder._maintain_lineage_tracking_for_mlflow_model")
@patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
@patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")
@patch("sagemaker.serve.builder.model_builder.save_pkl")
@@ -1705,6 +1746,7 @@ def test_build_mlflow_model_local_input_happy(
mock_save_pkl,
mock_prepare_for_torchserve,
mock_detect_fw_version,
+ mock_lineage_tracking,
):
# setup mocks
@@ -1750,6 +1792,85 @@ def test_build_mlflow_model_local_input_happy(
self.assertEqual(build_result.serve_settings, mock_setting_object)
self.assertEqual(builder.env_vars["MLFLOW_MODEL_FLAVOR"], "sklearn")
+ build_result.deploy(
+ initial_instance_count=1, instance_type=mock_instance_type, mode=Mode.SAGEMAKER_ENDPOINT
+ )
+ mock_lineage_tracking.assert_called_once()
+
+ @patch("os.makedirs", Mock())
+ @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
+ @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")
+ @patch("sagemaker.serve.builder.model_builder.save_pkl")
+ @patch("sagemaker.serve.builder.model_builder._copy_directory_contents")
+ @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path")
+ @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata")
+ @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model")
+ @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+ @patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
+ @patch("sagemaker.serve.builder.model_builder.Model")
+ @patch("builtins.open", new_callable=mock_open, read_data="data")
+ @patch("os.path.isfile", return_value=True)
+ @patch("os.path.exists")
+ def test_build_mlflow_model_local_input_happy_flavor_server_mismatch(
+ self,
+ mock_path_exists,
+ mock_is_file,
+ mock_open,
+ mock_sdk_model,
+ mock_sageMakerEndpointMode,
+ mock_serveSettings,
+ mock_detect_container,
+ mock_get_all_flavor_metadata,
+ mock_generate_mlflow_artifact_path,
+ mock_copy_directory_contents,
+ mock_save_pkl,
+ mock_prepare_for_torchserve,
+ mock_detect_fw_version,
+ ):
+ # setup mocks
+
+ mock_detect_container.return_value = mock_image_uri
+ mock_get_all_flavor_metadata.return_value = {"sklearn": "some_data"}
+ mock_generate_mlflow_artifact_path.return_value = "some_path"
+
+ mock_prepare_for_torchserve.return_value = mock_secret_key
+
+ # Mock _ServeSettings
+ mock_setting_object = mock_serveSettings.return_value
+ mock_setting_object.role_arn = mock_role_arn
+ mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+ mock_path_exists.side_effect = lambda path: path == "test_path"
+
+ mock_mode = Mock()
+ mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: (
+ mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None
+ )
+ mock_mode.prepare.return_value = (
+ model_data,
+ ENV_VAR_PAIR,
+ )
+
+ updated_env_var = deepcopy(ENV_VARS)
+ updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "sklearn"})
+ mock_model_obj = Mock()
+ mock_sdk_model.return_value = mock_model_obj
+
+ mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent"
+
+ # run
+ builder = ModelBuilder(
+ schema_builder=schema_builder,
+ model_metadata={"MLFLOW_MODEL_PATH": MODEL_PATH},
+ model_server=ModelServer.TENSORFLOW_SERVING,
+ )
+ with self.assertRaises(ValueError):
+ builder.build(
+ Mode.SAGEMAKER_ENDPOINT,
+ mock_role_arn,
+ mock_session,
+ )
+
@patch("os.makedirs", Mock())
@patch("sagemaker.serve.builder.model_builder.S3Downloader.list")
@patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
@@ -1899,3 +2020,240 @@ def test_build_mlflow_model_s3_input_non_mlflow_case(
mock_role_arn,
mock_session,
)
+
+ @patch("os.makedirs", Mock())
+ @patch("sagemaker.serve.builder.model_builder._maintain_lineage_tracking_for_mlflow_model")
+ @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving")
+ @patch("sagemaker.serve.builder.model_builder.S3Downloader.list")
+ @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
+ @patch("sagemaker.serve.builder.tf_serving_builder.save_pkl")
+ @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts")
+ @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path")
+ @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata")
+ @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model")
+ @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+ @patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
+ @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel")
+ @patch("builtins.open", new_callable=mock_open, read_data="data")
+ @patch("os.path.exists")
+ def test_build_mlflow_model_s3_input_tensorflow_serving_happy(
+ self,
+ mock_path_exists,
+ mock_open,
+ mock_sdk_model,
+ mock_sageMakerEndpointMode,
+ mock_serveSettings,
+ mock_detect_container,
+ mock_get_all_flavor_metadata,
+ mock_generate_mlflow_artifact_path,
+ mock_download_s3_artifacts,
+ mock_save_pkl,
+ mock_detect_fw_version,
+ mock_s3_downloader,
+ mock_prepare_for_tf_serving,
+ mock_lineage_tracking,
+ ):
+ # setup mocks
+ mock_s3_downloader.return_value = ["s3://some_path/MLmodel"]
+
+ mock_detect_container.return_value = mock_image_uri
+ mock_get_all_flavor_metadata.return_value = {"tensorflow": "some_data"}
+ mock_generate_mlflow_artifact_path.return_value = "some_path"
+
+ mock_prepare_for_tf_serving.return_value = mock_secret_key
+
+ # Mock _ServeSettings
+ mock_setting_object = mock_serveSettings.return_value
+ mock_setting_object.role_arn = mock_role_arn
+ mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+ mock_path_exists.side_effect = lambda path: path == "test_path"
+
+ mock_mode = Mock()
+ mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: (
+ mock_mode
+ if inference_spec is None and model_server == ModelServer.TENSORFLOW_SERVING
+ else None
+ )
+ mock_mode.prepare.return_value = (
+ model_data,
+ ENV_VAR_PAIR,
+ )
+
+ updated_env_var = deepcopy(ENV_VARS)
+ updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "tensorflow"})
+ mock_model_obj = Mock()
+ mock_sdk_model.return_value = mock_model_obj
+
+ mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent"
+
+ # run
+ builder = ModelBuilder(
+ schema_builder=schema_builder, model_metadata={"MLFLOW_MODEL_PATH": "s3://test_path/"}
+ )
+ build_result = builder.build(sagemaker_session=mock_session)
+ self.assertEqual(mock_model_obj, build_result)
+ self.assertEqual(build_result.mode, Mode.SAGEMAKER_ENDPOINT)
+ self.assertEqual(build_result.modes, {str(Mode.SAGEMAKER_ENDPOINT): mock_mode})
+ self.assertEqual(build_result.serve_settings, mock_setting_object)
+ self.assertEqual(builder.env_vars["MLFLOW_MODEL_FLAVOR"], "tensorflow")
+
+ build_result.deploy(
+ initial_instance_count=1, instance_type=mock_instance_type, mode=Mode.SAGEMAKER_ENDPOINT
+ )
+ mock_lineage_tracking.assert_called_once()
+
+ @patch("os.makedirs", Mock())
+ @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving")
+ @patch("sagemaker.serve.builder.model_builder.S3Downloader.list")
+ @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
+ @patch("sagemaker.serve.builder.tf_serving_builder.save_pkl")
+ @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts")
+ @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path")
+ @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata")
+ @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model")
+ @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+ @patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
+ @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel")
+ @patch("builtins.open", new_callable=mock_open, read_data="data")
+ @patch("os.path.exists")
+ def test_build_mlflow_model_s3_input_tensorflow_serving_local_mode_happy(
+ self,
+ mock_path_exists,
+ mock_open,
+ mock_sdk_model,
+ mock_local_container_mode,
+ mock_serveSettings,
+ mock_detect_container,
+ mock_get_all_flavor_metadata,
+ mock_generate_mlflow_artifact_path,
+ mock_download_s3_artifacts,
+ mock_save_pkl,
+ mock_detect_fw_version,
+ mock_s3_downloader,
+ mock_prepare_for_tf_serving,
+ ):
+ # setup mocks
+ mock_s3_downloader.return_value = ["s3://some_path/MLmodel"]
+
+ mock_detect_container.return_value = mock_image_uri
+ mock_get_all_flavor_metadata.return_value = {"tensorflow": "some_data"}
+ mock_generate_mlflow_artifact_path.return_value = "some_path"
+
+ mock_prepare_for_tf_serving.return_value = mock_secret_key
+
+ # Mock _ServeSettings
+ mock_setting_object = mock_serveSettings.return_value
+ mock_setting_object.role_arn = mock_role_arn
+ mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+ mock_path_exists.side_effect = lambda path: path == "test_path"
+
+ mock_mode = Mock()
+ mock_local_container_mode.return_value = mock_mode
+ mock_mode.prepare.return_value = (
+ model_data,
+ ENV_VAR_PAIR,
+ )
+
+ updated_env_var = deepcopy(ENV_VARS)
+ updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "tensorflow"})
+ mock_model_obj = Mock()
+ mock_sdk_model.return_value = mock_model_obj
+
+ mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent"
+
+ # run
+ builder = ModelBuilder(
+ mode=Mode.LOCAL_CONTAINER,
+ schema_builder=schema_builder,
+ model_metadata={"MLFLOW_MODEL_PATH": "s3://test_path/"},
+ )
+ build_result = builder.build(sagemaker_session=mock_session)
+ self.assertEqual(mock_model_obj, build_result)
+ self.assertEqual(build_result.mode, Mode.LOCAL_CONTAINER)
+ self.assertEqual(build_result.modes, {str(Mode.LOCAL_CONTAINER): mock_mode})
+ self.assertEqual(build_result.serve_settings, mock_setting_object)
+ self.assertEqual(builder.env_vars["MLFLOW_MODEL_FLAVOR"], "tensorflow")
+
+ predictor = build_result.deploy()
+ assert isinstance(predictor, TensorflowServingLocalPredictor)
+
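A subtlety behind the `mock_mode.prepare` setup in the local-mode test above: when a `Mock` has both `side_effect` and `return_value` configured, `side_effect` wins unless it returns `mock.DEFAULT`, so a stray callable `side_effect` silently masks the tuple you meant to return. A quick stdlib-only illustration:

```python
from unittest import mock

m = mock.Mock()
m.prepare.return_value = ("model_data", {"ENV": "VAR"})
assert m.prepare() == ("model_data", {"ENV": "VAR"})

# A callable side_effect overrides return_value...
m.prepare.side_effect = lambda: None
assert m.prepare() is None

# ...unless it returns mock.DEFAULT, which defers to return_value.
m.prepare.side_effect = lambda: mock.DEFAULT
assert m.prepare() == ("model_data", {"ENV": "VAR"})
```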
+ @patch("os.makedirs", Mock())
+ @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving")
+ @patch("sagemaker.serve.builder.model_builder.S3Downloader.list")
+ @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
+ @patch("sagemaker.serve.builder.model_builder.save_pkl")
+ @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts")
+ @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path")
+ @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata")
+ @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model")
+ @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+ @patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
+ @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel")
+ @patch("builtins.open", new_callable=mock_open, read_data="data")
+ @patch("os.path.exists")
+ def test_build_tensorflow_serving_non_mlflow_case(
+ self,
+ mock_path_exists,
+ mock_open,
+ mock_sdk_model,
+ mock_sageMakerEndpointMode,
+ mock_serveSettings,
+ mock_detect_container,
+ mock_get_all_flavor_metadata,
+ mock_generate_mlflow_artifact_path,
+ mock_download_s3_artifacts,
+ mock_save_pkl,
+ mock_detect_fw_version,
+ mock_s3_downloader,
+ mock_prepare_for_tf_serving,
+ ):
+ mock_s3_downloader.return_value = []
+ mock_detect_container.return_value = mock_image_uri
+ mock_get_all_flavor_metadata.return_value = {"tensorflow": "some_data"}
+ mock_generate_mlflow_artifact_path.return_value = "some_path"
+
+ mock_prepare_for_tf_serving.return_value = mock_secret_key
+
+ # Mock _ServeSettings
+ mock_setting_object = mock_serveSettings.return_value
+ mock_setting_object.role_arn = mock_role_arn
+ mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+ mock_path_exists.side_effect = lambda path: path == "test_path"
+
+ mock_mode = Mock()
+ mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: (
+ mock_mode
+ if inference_spec is None and model_server == ModelServer.TENSORFLOW_SERVING
+ else None
+ )
+ mock_mode.prepare.return_value = (
+ model_data,
+ ENV_VAR_PAIR,
+ )
+
+ updated_env_var = deepcopy(ENV_VARS)
+ updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "tensorflow"})
+ mock_model_obj = Mock()
+ mock_sdk_model.return_value = mock_model_obj
+
+ mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent"
+
+ # run
+ builder = ModelBuilder(
+ model=mock_fw_model,
+ schema_builder=schema_builder,
+ model_server=ModelServer.TENSORFLOW_SERVING,
+ )
+
+ self.assertRaisesRegex(
+ Exception,
+ "Tensorflow Serving is currently only supported for mlflow models.",
+ builder.build,
+ Mode.SAGEMAKER_ENDPOINT,
+ mock_role_arn,
+ mock_session,
+ )
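One caveat on the `assertRaisesRegex` call above (and the `pytest.raises(..., match=...)` calls later in this diff): the pattern is a regular expression searched against the exception message, so the literal dots in "mlflow models." happen to match any character. When strict matching matters, `re.escape` keeps metacharacters literal. A small self-contained sketch:

```python
import re
import unittest


class RegexMatchDemo(unittest.TestCase):
    def test_escaped_match(self):
        message = "Tensorflow Serving is currently only supported for mlflow models."
        # re.escape keeps '.' and other regex metacharacters literal.
        with self.assertRaisesRegex(ValueError, re.escape(message)):
            raise ValueError(message)


if __name__ == "__main__":
    unittest.main()
```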
diff --git a/tests/unit/sagemaker/serve/builder/test_tei_builder.py b/tests/unit/sagemaker/serve/builder/test_tei_builder.py
new file mode 100644
index 0000000000..4a75174bfc
--- /dev/null
+++ b/tests/unit/sagemaker/serve/builder/test_tei_builder.py
@@ -0,0 +1,152 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+from unittest.mock import MagicMock, patch
+
+import unittest
+from sagemaker.serve.builder.model_builder import ModelBuilder
+from sagemaker.serve.mode.function_pointers import Mode
+from tests.unit.sagemaker.serve.constants import MOCK_VPC_CONFIG
+
+from sagemaker.serve.utils.predictors import TeiLocalModePredictor
+
+mock_model_id = "bert-base-uncased"
+mock_prompt = "The man worked as a [MASK]."
+mock_sample_input = {"inputs": mock_prompt}
+mock_sample_output = [
+ {
+ "score": 0.0974755585193634,
+ "token": 10533,
+ "token_str": "carpenter",
+ "sequence": "the man worked as a carpenter.",
+ },
+ {
+ "score": 0.052383411675691605,
+ "token": 15610,
+ "token_str": "waiter",
+ "sequence": "the man worked as a waiter.",
+ },
+ {
+ "score": 0.04962712526321411,
+ "token": 13362,
+ "token_str": "barber",
+ "sequence": "the man worked as a barber.",
+ },
+ {
+ "score": 0.0378861166536808,
+ "token": 15893,
+ "token_str": "mechanic",
+ "sequence": "the man worked as a mechanic.",
+ },
+ {
+ "score": 0.037680838257074356,
+ "token": 18968,
+ "token_str": "salesman",
+ "sequence": "the man worked as a salesman.",
+ },
+]
+mock_schema_builder = MagicMock()
+mock_schema_builder.sample_input = mock_sample_input
+mock_schema_builder.sample_output = mock_sample_output
+MOCK_IMAGE_URI = (
+ "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
+ "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0"
+)
+
+
+class TestTEIBuilder(unittest.TestCase):
+ @patch(
+ "sagemaker.serve.builder.tei_builder._get_nb_instance",
+ return_value="ml.g5.24xlarge",
+ )
+ @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None)
+ def test_build_deploy_for_tei_local_container_and_remote_container(
+ self,
+ mock_get_nb_instance,
+ mock_telemetry,
+ ):
+ builder = ModelBuilder(
+ model=mock_model_id,
+ schema_builder=mock_schema_builder,
+ mode=Mode.LOCAL_CONTAINER,
+ vpc_config=MOCK_VPC_CONFIG,
+ model_metadata={
+ "HF_TASK": "sentence-similarity",
+ },
+ )
+
+ builder._prepare_for_mode = MagicMock()
+ builder._prepare_for_mode.side_effect = None
+
+ model = builder.build()
+ builder.serve_settings.telemetry_opt_out = True
+
+ builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
+ predictor = model.deploy(model_data_download_timeout=1800)
+
+ assert model.vpc_config == MOCK_VPC_CONFIG
+ assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
+ assert isinstance(predictor, TeiLocalModePredictor)
+
+ assert builder.nb_instance_type == "ml.g5.24xlarge"
+
+ builder._original_deploy = MagicMock()
+ builder._prepare_for_mode.return_value = (None, {})
+ predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
+ assert "HF_MODEL_ID" in model.env
+
+ with self.assertRaises(ValueError) as _:
+ model.deploy(mode=Mode.IN_PROCESS)
+
+ @patch(
+ "sagemaker.serve.builder.tei_builder._get_nb_instance",
+ return_value="ml.g5.24xlarge",
+ )
+ @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None)
+ def test_image_uri_override(
+ self,
+ mock_get_nb_instance,
+ mock_telemetry,
+ ):
+ builder = ModelBuilder(
+ model=mock_model_id,
+ schema_builder=mock_schema_builder,
+ mode=Mode.LOCAL_CONTAINER,
+ image_uri=MOCK_IMAGE_URI,
+ model_metadata={
+ "HF_TASK": "sentence-similarity",
+ },
+ )
+
+ builder._prepare_for_mode = MagicMock()
+ builder._prepare_for_mode.side_effect = None
+
+ model = builder.build()
+ builder.serve_settings.telemetry_opt_out = True
+
+ builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
+ predictor = model.deploy(model_data_download_timeout=1800)
+
+ assert builder.image_uri == MOCK_IMAGE_URI
+ assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
+ assert isinstance(predictor, TeiLocalModePredictor)
+
+ assert builder.nb_instance_type == "ml.g5.24xlarge"
+
+ builder._original_deploy = MagicMock()
+ builder._prepare_for_mode.return_value = (None, {})
+ predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
+ assert "HF_MODEL_ID" in model.env
+
+ with self.assertRaises(ValueError) as _:
+ model.deploy(mode=Mode.IN_PROCESS)
diff --git a/tests/unit/sagemaker/serve/builder/test_tensorflow_serving_builder.py b/tests/unit/sagemaker/serve/builder/test_tensorflow_serving_builder.py
new file mode 100644
index 0000000000..9d51b04e08
--- /dev/null
+++ b/tests/unit/sagemaker/serve/builder/test_tensorflow_serving_builder.py
@@ -0,0 +1,75 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+from unittest.mock import MagicMock, patch
+
+import unittest
+from pathlib import Path
+
+from sagemaker.serve import ModelBuilder, ModelServer
+
+
+class TestTensorflowServingBuilder(unittest.TestCase):
+ def setUp(self):
+ self.instance = ModelBuilder()
+ self.instance.model_server = ModelServer.TENSORFLOW_SERVING
+ self.instance.model_path = "/fake/model/path"
+ self.instance.image_uri = "fake_image_uri"
+ self.instance.s3_upload_path = "s3://bucket/path"
+ self.instance.serve_settings = MagicMock(role_arn="fake_role_arn")
+ self.instance.schema_builder = MagicMock()
+ self.instance.env_vars = {}
+ self.instance.sagemaker_session = MagicMock()
+ self.instance.image_config = {}
+ self.instance.vpc_config = {}
+ self.instance.modes = {}
+
+ @patch("os.makedirs")
+ @patch("os.path.exists")
+ @patch("sagemaker.serve.builder.tf_serving_builder.save_pkl")
+ def test_save_schema_builder(self, mock_save_pkl, mock_exists, mock_makedirs):
+ mock_exists.return_value = False
+ self.instance._save_schema_builder()
+ mock_makedirs.assert_called_once_with(self.instance.model_path)
+ code_path = Path(self.instance.model_path).joinpath("code")
+ mock_save_pkl.assert_called_once_with(code_path, self.instance.schema_builder)
+
+ @patch("sagemaker.serve.builder.tf_serving_builder.TensorflowServing._get_client_translators")
+ @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowPredictor")
+ def test_get_tensorflow_predictor(self, mock_predictor, mock_get_marshaller):
+ endpoint_name = "test_endpoint"
+ predictor = self.instance._get_tensorflow_predictor(
+ endpoint_name, self.instance.sagemaker_session
+ )
+ mock_predictor.assert_called_once_with(
+ endpoint_name=endpoint_name,
+ sagemaker_session=self.instance.sagemaker_session,
+ serializer=self.instance.schema_builder.custom_input_translator,
+ deserializer=self.instance.schema_builder.custom_output_translator,
+ )
+ self.assertEqual(predictor, mock_predictor.return_value)
+
+ @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel")
+ def test_create_tensorflow_model(self, mock_model):
+ model = self.instance._create_tensorflow_model()
+ mock_model.assert_called_once_with(
+ image_uri=self.instance.image_uri,
+ image_config=self.instance.image_config,
+ vpc_config=self.instance.vpc_config,
+ model_data=self.instance.s3_upload_path,
+ role=self.instance.serve_settings.role_arn,
+ env=self.instance.env_vars,
+ sagemaker_session=self.instance.sagemaker_session,
+ predictor_cls=self.instance._get_tensorflow_predictor,
+ )
+ self.assertEqual(model, mock_model.return_value)
diff --git a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py
index e17364f22d..9ea797adc2 100644
--- a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py
+++ b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py
@@ -58,6 +58,10 @@
mock_schema_builder = MagicMock()
mock_schema_builder.sample_input = mock_sample_input
mock_schema_builder.sample_output = mock_sample_output
+MOCK_IMAGE_URI = (
+ "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
+ "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0"
+)
class TestTransformersBuilder(unittest.TestCase):
@@ -100,3 +104,69 @@ def test_build_deploy_for_transformers_local_container_and_remote_container(
with self.assertRaises(ValueError) as _:
model.deploy(mode=Mode.IN_PROCESS)
+
+ @patch(
+ "sagemaker.serve.builder.transformers_builder._get_nb_instance",
+ return_value="ml.g5.24xlarge",
+ )
+ @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
+ def test_image_uri_override(
+ self,
+ mock_get_nb_instance,
+ mock_telemetry,
+ ):
+ builder = ModelBuilder(
+ model=mock_model_id,
+ schema_builder=mock_schema_builder,
+ mode=Mode.LOCAL_CONTAINER,
+ image_uri=MOCK_IMAGE_URI,
+ )
+
+ builder._prepare_for_mode = MagicMock()
+ builder._prepare_for_mode.side_effect = None
+
+ model = builder.build()
+ builder.serve_settings.telemetry_opt_out = True
+
+ builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock()
+ predictor = model.deploy(model_data_download_timeout=1800)
+
+ assert builder.image_uri == MOCK_IMAGE_URI
+ assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800"
+ assert isinstance(predictor, TransformersLocalModePredictor)
+
+ assert builder.nb_instance_type == "ml.g5.24xlarge"
+
+ builder._original_deploy = MagicMock()
+ builder._prepare_for_mode.return_value = (None, {})
+ predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn")
+ assert "HF_MODEL_ID" in model.env
+
+ with self.assertRaises(ValueError) as _:
+ model.deploy(mode=Mode.IN_PROCESS)
+
+ @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
+ @patch(
+ "sagemaker.serve.builder.transformers_builder._get_nb_instance",
+ return_value="ml.g5.24xlarge",
+ )
+ @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
+ @patch(
+ "sagemaker.huggingface.llm_utils.get_huggingface_model_metadata",
+ return_value=None,
+ )
+ def test_failure_hf_md(
+ self, mock_model_md, mock_get_nb_instance, mock_telemetry, mock_build_for_transformers
+ ):
+ builder = ModelBuilder(
+ model=mock_model_id,
+ schema_builder=mock_schema_builder,
+ mode=Mode.LOCAL_CONTAINER,
+ )
+
+ builder._prepare_for_mode = MagicMock()
+ builder._prepare_for_mode.side_effect = None
+
+ builder.build()
+
+ mock_build_for_transformers.assert_called_once()
diff --git a/tests/unit/sagemaker/serve/constants.py b/tests/unit/sagemaker/serve/constants.py
index db9dd623d8..5c40c1bf64 100644
--- a/tests/unit/sagemaker/serve/constants.py
+++ b/tests/unit/sagemaker/serve/constants.py
@@ -15,3 +15,153 @@
MOCK_IMAGE_CONFIG = {"RepositoryAccessMode": "Vpc"}
MOCK_VPC_CONFIG = {"Subnets": ["subnet-1234"], "SecurityGroupIds": ["sg123"]}
+DEPLOYMENT_CONFIGS = [
+ {
+ "ConfigName": "neuron-inference",
+ "BenchmarkMetrics": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S"},
+ {"name": "Throughput", "value": "1867", "unit": "Tokens/S"},
+ ],
+ "DeploymentConfig": {
+ "ModelDataDownloadTimeout": 1200,
+ "ContainerStartupHealthCheckTimeout": 1200,
+ "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4"
+ ".0-gpu-py310-cu121-ubuntu20.04",
+ "ModelData": {
+ "S3DataSource": {
+ "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration"
+ "-llama-2-7b/artifacts/inference-prepack/v1.0.0/",
+ "S3DataType": "S3Prefix",
+ "CompressionType": "None",
+ }
+ },
+ "InstanceType": "ml.p2.xlarge",
+ "Environment": {
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "ENDPOINT_SERVER_TIMEOUT": "3600",
+ "MODEL_CACHE_ROOT": "/opt/ml/model",
+ "SAGEMAKER_ENV": "1",
+ "HF_MODEL_ID": "/opt/ml/model",
+ "MAX_INPUT_LENGTH": "4095",
+ "MAX_TOTAL_TOKENS": "4096",
+ "SM_NUM_GPUS": "1",
+ "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+ },
+ "ComputeResourceRequirements": {
+ "MinMemoryRequiredInMb": 16384,
+ "NumberOfAcceleratorDevicesRequired": 1,
+ },
+ },
+ },
+ {
+ "ConfigName": "neuron-inference-budget",
+ "BenchmarkMetrics": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S"},
+ {"name": "Throughput", "value": "1867", "unit": "Tokens/S"},
+ ],
+ "DeploymentConfig": {
+ "ModelDataDownloadTimeout": 1200,
+ "ContainerStartupHealthCheckTimeout": 1200,
+ "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4"
+ ".0-gpu-py310-cu121-ubuntu20.04",
+ "ModelData": {
+ "S3DataSource": {
+ "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration"
+ "-llama-2-7b/artifacts/inference-prepack/v1.0.0/",
+ "S3DataType": "S3Prefix",
+ "CompressionType": "None",
+ }
+ },
+ "InstanceType": "ml.p2.xlarge",
+ "Environment": {
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "ENDPOINT_SERVER_TIMEOUT": "3600",
+ "MODEL_CACHE_ROOT": "/opt/ml/model",
+ "SAGEMAKER_ENV": "1",
+ "HF_MODEL_ID": "/opt/ml/model",
+ "MAX_INPUT_LENGTH": "4095",
+ "MAX_TOTAL_TOKENS": "4096",
+ "SM_NUM_GPUS": "1",
+ "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+ },
+ "ComputeResourceRequirements": {
+ "MinMemoryRequiredInMb": 16384,
+ "NumberOfAcceleratorDevicesRequired": 1,
+ },
+ },
+ },
+ {
+ "ConfigName": "gpu-inference-budget",
+ "BenchmarkMetrics": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S"},
+ {"name": "Throughput", "value": "1867", "unit": "Tokens/S"},
+ ],
+ "DeploymentConfig": {
+ "ModelDataDownloadTimeout": 1200,
+ "ContainerStartupHealthCheckTimeout": 1200,
+ "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4"
+ ".0-gpu-py310-cu121-ubuntu20.04",
+ "ModelData": {
+ "S3DataSource": {
+ "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration"
+ "-llama-2-7b/artifacts/inference-prepack/v1.0.0/",
+ "S3DataType": "S3Prefix",
+ "CompressionType": "None",
+ }
+ },
+ "InstanceType": "ml.p2.xlarge",
+ "Environment": {
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "ENDPOINT_SERVER_TIMEOUT": "3600",
+ "MODEL_CACHE_ROOT": "/opt/ml/model",
+ "SAGEMAKER_ENV": "1",
+ "HF_MODEL_ID": "/opt/ml/model",
+ "MAX_INPUT_LENGTH": "4095",
+ "MAX_TOTAL_TOKENS": "4096",
+ "SM_NUM_GPUS": "1",
+ "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+ },
+ "ComputeResourceRequirements": {
+ "MinMemoryRequiredInMb": 16384,
+ "NumberOfAcceleratorDevicesRequired": 1,
+ },
+ },
+ },
+ {
+ "ConfigName": "gpu-inference",
+ "BenchmarkMetrics": [
+ {"name": "Latency", "value": "100", "unit": "Tokens/S"},
+ {"name": "Throughput", "value": "1867", "unit": "Tokens/S"},
+ ],
+ "DeploymentConfig": {
+ "ModelDataDownloadTimeout": 1200,
+ "ContainerStartupHealthCheckTimeout": 1200,
+ "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4"
+ ".0-gpu-py310-cu121-ubuntu20.04",
+ "ModelData": {
+ "S3DataSource": {
+ "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration"
+ "-llama-2-7b/artifacts/inference-prepack/v1.0.0/",
+ "S3DataType": "S3Prefix",
+ "CompressionType": "None",
+ }
+ },
+ "InstanceType": "ml.p2.xlarge",
+ "Environment": {
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "ENDPOINT_SERVER_TIMEOUT": "3600",
+ "MODEL_CACHE_ROOT": "/opt/ml/model",
+ "SAGEMAKER_ENV": "1",
+ "HF_MODEL_ID": "/opt/ml/model",
+ "MAX_INPUT_LENGTH": "4095",
+ "MAX_TOTAL_TOKENS": "4096",
+ "SM_NUM_GPUS": "1",
+ "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+ },
+ "ComputeResourceRequirements": {
+ "MinMemoryRequiredInMb": 16384,
+ "NumberOfAcceleratorDevicesRequired": 1,
+ },
+ },
+ },
+]
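The four `DEPLOYMENT_CONFIGS` entries above differ only in `ConfigName`; every other field is byte-for-byte identical. If this fixture grows, a small factory would keep the shared body in one place. A possible refactor (a sketch, not part of this change; the shared `DeploymentConfig` body is elided):

```python
from copy import deepcopy

_BASE_DEPLOYMENT_CONFIG = {
    "BenchmarkMetrics": [
        {"name": "Latency", "value": "100", "unit": "Tokens/S"},
        {"name": "Throughput", "value": "1867", "unit": "Tokens/S"},
    ],
    "DeploymentConfig": {
        # shared body elided; identical across all four entries above
    },
}


def _make_deployment_config(name):
    """Deep-copy the shared template under a distinct ConfigName."""
    config = deepcopy(_BASE_DEPLOYMENT_CONFIG)
    config["ConfigName"] = name
    return config


DEPLOYMENT_CONFIGS = [
    _make_deployment_config(name)
    for name in (
        "neuron-inference",
        "neuron-inference-budget",
        "gpu-inference-budget",
        "gpu-inference",
    )
]
```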
diff --git a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py
index 154b6b7d95..23d1315647 100644
--- a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py
+++ b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py
@@ -13,6 +13,7 @@
from __future__ import absolute_import
import os
+from pathlib import Path
from unittest.mock import patch, MagicMock, mock_open
import pytest
@@ -21,6 +22,7 @@
from sagemaker.serve import ModelServer
from sagemaker.serve.model_format.mlflow.constants import (
MLFLOW_PYFUNC,
+ TENSORFLOW_SAVED_MODEL_NAME,
)
from sagemaker.serve.model_format.mlflow.utils import (
_get_default_model_server_for_mlflow,
@@ -35,6 +37,8 @@
_select_container_for_mlflow_model,
_validate_input_for_mlflow,
_copy_directory_contents,
+ _move_contents,
+ _get_saved_model_path_for_tensorflow_and_keras_flavor,
)
@@ -414,11 +418,71 @@ def test_select_container_for_mlflow_model_no_dlc_detected(
)
+@patch("sagemaker.image_uris.retrieve")
+@patch("sagemaker.serve.model_format.mlflow.utils._cast_to_compatible_version")
+@patch("sagemaker.serve.model_format.mlflow.utils._get_framework_version_from_requirements")
+@patch(
+ "sagemaker.serve.model_format.mlflow.utils._get_python_version_from_parsed_mlflow_model_file"
+)
+@patch("sagemaker.serve.model_format.mlflow.utils._get_all_flavor_metadata")
+@patch("sagemaker.serve.model_format.mlflow.utils._generate_mlflow_artifact_path")
+def test_select_container_for_mlflow_model_no_framework_version_detected(
+ mock_generate_mlflow_artifact_path,
+ mock_get_all_flavor_metadata,
+ mock_get_python_version_from_parsed_mlflow_model_file,
+ mock_get_framework_version_from_requirements,
+ mock_cast_to_compatible_version,
+ mock_image_uris_retrieve,
+):
+ mlflow_model_src_path = "/path/to/mlflow_model"
+ deployment_flavor = "pytorch"
+ region = "us-west-2"
+ instance_type = "ml.m5.xlarge"
+
+ mock_requirements_path = "/path/to/requirements.txt"
+ mock_metadata_path = "/path/to/mlmodel"
+ mock_flavor_metadata = {"pytorch": {"some_key": "some_value"}}
+ mock_python_version = "3.8.6"
+
+ mock_generate_mlflow_artifact_path.side_effect = lambda path, artifact: (
+ mock_requirements_path if artifact == "requirements.txt" else mock_metadata_path
+ )
+ mock_get_all_flavor_metadata.return_value = mock_flavor_metadata
+ mock_get_python_version_from_parsed_mlflow_model_file.return_value = mock_python_version
+ mock_get_framework_version_from_requirements.return_value = None
+
+ with pytest.raises(
+ ValueError,
+ match="Unable to auto detect framework version. Please provide framework "
+ "pytorch as part of the requirements.txt file for deployment flavor "
+ "pytorch",
+ ):
+ _select_container_for_mlflow_model(
+ mlflow_model_src_path, deployment_flavor, region, instance_type
+ )
+
+ mock_generate_mlflow_artifact_path.assert_any_call(
+ mlflow_model_src_path, "requirements.txt"
+ )
+ mock_generate_mlflow_artifact_path.assert_any_call(mlflow_model_src_path, "MLmodel")
+ mock_get_all_flavor_metadata.assert_called_once_with(mock_metadata_path)
+ mock_get_framework_version_from_requirements.assert_called_once_with(
+ deployment_flavor, mock_requirements_path
+ )
+ mock_cast_to_compatible_version.assert_not_called()
+ mock_image_uris_retrieve.assert_not_called()
+
+
def test_validate_input_for_mlflow():
- _validate_input_for_mlflow(ModelServer.TORCHSERVE)
+ _validate_input_for_mlflow(ModelServer.TORCHSERVE, "pytorch")
with pytest.raises(ValueError):
- _validate_input_for_mlflow(ModelServer.DJL_SERVING)
+ _validate_input_for_mlflow(ModelServer.DJL_SERVING, "pytorch")
+
+
+def test_validate_input_for_mlflow_non_supported_flavor_with_tf_serving():
+ with pytest.raises(ValueError):
+ _validate_input_for_mlflow(ModelServer.TENSORFLOW_SERVING, "pytorch")
@patch("sagemaker.serve.model_format.mlflow.utils.shutil.copy2")
@@ -472,3 +536,68 @@ def test_copy_directory_contents_handles_same_src_dst(
mock_os_walk.assert_not_called()
mock_os_makedirs.assert_not_called()
mock_shutil_copy2.assert_not_called()
+
+
+@patch("os.path.abspath")
+@patch("os.walk")
+def test_get_saved_model_path_found(mock_os_walk, mock_os_abspath):
+ mock_os_walk.return_value = [
+ ("/root/folder1", ("subfolder",), ()),
+ ("/root/folder1/subfolder", (), (TENSORFLOW_SAVED_MODEL_NAME,)),
+ ]
+ expected_path = "/root/folder1/subfolder"
+ mock_os_abspath.return_value = expected_path
+
+ # Call the function
+ result = _get_saved_model_path_for_tensorflow_and_keras_flavor("/root/folder1")
+
+ # Assertions
+ mock_os_walk.assert_called_once_with("/root/folder1")
+ mock_os_abspath.assert_called_once_with("/root/folder1/subfolder")
+ assert result == expected_path
+
+
+@patch("os.path.abspath")
+@patch("os.walk")
+def test_get_saved_model_path_not_found(mock_os_walk, mock_os_abspath):
+ mock_os_walk.return_value = [
+ ("/root/folder2", ("subfolder",), ()),
+ ("/root/folder2/subfolder", (), ("not_saved_model.pb",)),
+ ]
+
+ result = _get_saved_model_path_for_tensorflow_and_keras_flavor("/root/folder2")
+
+ mock_os_walk.assert_called_once_with("/root/folder2")
+ mock_os_abspath.assert_not_called()
+ assert result is None
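Together, these two tests pin down the contract of `_get_saved_model_path_for_tensorflow_and_keras_flavor`: walk the tree, return the absolute path of the directory containing the SavedModel file, or `None` when there is none. A minimal standalone sketch consistent with that contract (the utility's real body may differ; `saved_model.pb` is the conventional TensorFlow SavedModel file name):

```python
import os

TENSORFLOW_SAVED_MODEL_NAME = "saved_model.pb"


def find_saved_model_dir(root):
    """Return the absolute path of the first directory under `root`
    containing a SavedModel file, or None if no such directory exists."""
    for current_dir, _dirs, files in os.walk(root):
        if TENSORFLOW_SAVED_MODEL_NAME in files:
            return os.path.abspath(current_dir)
    return None
```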
+
+
+@patch("sagemaker.serve.model_format.mlflow.utils.shutil.move")
+@patch("sagemaker.serve.model_format.mlflow.utils.Path.iterdir")
+@patch("sagemaker.serve.model_format.mlflow.utils.Path.mkdir")
+def test_move_contents_handles_same_src_dst(mock_mkdir, mock_iterdir, mock_shutil_move):
+ src_dir = "/fake/source/dir"
+ dest_dir = "/fake/source/./dir"
+
+ mock_iterdir.return_value = []
+
+ _move_contents(src_dir, dest_dir)
+
+ mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
+ mock_shutil_move.assert_not_called()
+
+
+@patch("sagemaker.serve.model_format.mlflow.utils.shutil.move")
+@patch("sagemaker.serve.model_format.mlflow.utils.Path.iterdir")
+@patch("sagemaker.serve.model_format.mlflow.utils.Path.mkdir")
+def test_move_contents_with_actual_files(mock_mkdir, mock_iterdir, mock_shutil_move):
+ src_dir = Path("/fake/source/dir")
+ dest_dir = Path("/fake/destination/dir")
+
+ file_path = src_dir / "testfile.txt"
+ mock_iterdir.return_value = [file_path]
+
+ _move_contents(src_dir, dest_dir)
+
+ mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
+ mock_shutil_move.assert_called_once_with(str(file_path), str(dest_dir / "testfile.txt"))
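These `_move_contents` tests imply a simple shape for the helper: create the destination (tolerating pre-existing directories), skip the degenerate case where source and destination normalize to the same directory, and otherwise move each immediate child across. A sketch under those assumptions:

```python
import shutil
from pathlib import Path


def move_contents(src_dir, dest_dir):
    """Move every immediate child of src_dir into dest_dir."""
    src = Path(src_dir).resolve()
    dest = Path(dest_dir).resolve()
    dest.mkdir(parents=True, exist_ok=True)
    if src == dest:
        return  # "/a/b" and "/a/./b" resolve to the same directory
    for item in Path(src_dir).iterdir():
        shutil.move(str(item), str(dest / item.name))
```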
diff --git a/tests/unit/sagemaker/serve/model_server/tei/__init__.py b/tests/unit/sagemaker/serve/model_server/tei/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit/sagemaker/serve/model_server/tei/test_server.py b/tests/unit/sagemaker/serve/model_server/tei/test_server.py
new file mode 100644
index 0000000000..16dcf12b5a
--- /dev/null
+++ b/tests/unit/sagemaker/serve/model_server/tei/test_server.py
@@ -0,0 +1,150 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from pathlib import PosixPath
+from unittest import TestCase
+from unittest.mock import Mock, patch
+
+from docker.types import DeviceRequest
+from sagemaker.serve.model_server.tei.server import LocalTeiServing, SageMakerTeiServing
+from sagemaker.serve.utils.exceptions import LocalModelInvocationException
+
+TEI_IMAGE = (
+ "246618743249.dkr.ecr.us-west-2.amazonaws.com/tei:2.0.1-tei1.2.3-gpu-py310-cu122-ubuntu22.04"
+)
+MODEL_PATH = "model_path"
+ENV_VAR = {"KEY": "VALUE"}
+PAYLOAD = {
+ "inputs": {
+ "sourceSentence": "How cute your dog is!",
+ "sentences": ["The mitochondria is the powerhouse of the cell.", "Your dog is so cute."],
+ }
+}
+S3_URI = "s3://mock_model_data_uri"
+SECRET_KEY = "secret_key"
+INFER_RESPONSE = []
+
+
+class TeiServerTests(TestCase):
+ @patch("sagemaker.serve.model_server.tei.server.requests")
+ def test_start_invoke_destroy_local_tei_server(self, mock_requests):
+ mock_container = Mock()
+ mock_docker_client = Mock()
+ mock_docker_client.containers.run.return_value = mock_container
+
+ local_tei_server = LocalTeiServing()
+ mock_schema_builder = Mock()
+ mock_schema_builder.input_serializer.serialize.return_value = PAYLOAD
+ local_tei_server.schema_builder = mock_schema_builder
+
+ local_tei_server._start_tei_serving(
+ client=mock_docker_client,
+ model_path=MODEL_PATH,
+ secret_key=SECRET_KEY,
+ image=TEI_IMAGE,
+ env_vars=ENV_VAR,
+ )
+
+ mock_docker_client.containers.run.assert_called_once_with(
+ TEI_IMAGE,
+ shm_size="2G",
+ device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
+ network_mode="host",
+ detach=True,
+ auto_remove=True,
+ volumes={PosixPath("model_path/code"): {"bind": "/opt/ml/model/", "mode": "rw"}},
+ environment={
+ "TRANSFORMERS_CACHE": "/opt/ml/model/",
+ "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/",
+ "KEY": "VALUE",
+ "SAGEMAKER_SERVE_SECRET_KEY": "secret_key",
+ },
+ )
+
+ mock_response = Mock()
+ mock_requests.post.side_effect = lambda *args, **kwargs: mock_response
+ mock_response.content = INFER_RESPONSE
+
+ res = local_tei_server._invoke_tei_serving(
+ request=PAYLOAD, content_type="application/json", accept="application/json"
+ )
+
+ self.assertEqual(res, INFER_RESPONSE)
+
+ def test_tei_deep_ping(self):
+ mock_predictor = Mock()
+ mock_response = Mock()
+ mock_schema_builder = Mock()
+
+ mock_predictor.predict.side_effect = lambda *args, **kwargs: mock_response
+ mock_schema_builder.sample_input = PAYLOAD
+
+ local_tei_server = LocalTeiServing()
+ local_tei_server.schema_builder = mock_schema_builder
+ res = local_tei_server._tei_deep_ping(mock_predictor)
+
+ self.assertEqual(res, (True, mock_response))
+
+ def test_tei_deep_ping_invoke_ex(self):
+ mock_predictor = Mock()
+ mock_schema_builder = Mock()
+
+ mock_predictor.predict.side_effect = ValueError(
+ "422 Client Error: Unprocessable Entity for url:"
+ )
+ mock_schema_builder.sample_input = PAYLOAD
+
+ local_tei_server = LocalTeiServing()
+ local_tei_server.schema_builder = mock_schema_builder
+
+ self.assertRaises(
+ LocalModelInvocationException, lambda: local_tei_server._tei_deep_ping(mock_predictor)
+ )
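A note on the `side_effect` assignment above: when a `Mock`'s `side_effect` is an exception class or instance, calling the mock raises it, which is the most direct way to drive an error path. A tiny demonstration:

```python
from unittest.mock import Mock

predictor = Mock()
predictor.predict.side_effect = ValueError(
    "422 Client Error: Unprocessable Entity for url:"
)

try:
    predictor.predict({"inputs": "ping"})
except ValueError as err:
    # The configured exception surfaces on the call itself.
    assert "422 Client Error" in str(err)
```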
+
+ def test_tei_deep_ping_ex(self):
+ mock_predictor = Mock()
+
+ mock_predictor.predict.side_effect = lambda *args, **kwargs: Exception()
+
+ local_tei_server = LocalTeiServing()
+ res = local_tei_server._tei_deep_ping(mock_predictor)
+
+ self.assertEqual(res, (False, None))
+
+ @patch("sagemaker.serve.model_server.tei.server.S3Uploader")
+ def test_upload_artifacts_sagemaker_tei_server(self, mock_uploader):
+ mock_session = Mock()
+ mock_uploader.upload.side_effect = (
+ lambda *args, **kwargs: "s3://sagemaker-us-west-2-123456789123/tei-2024-05-20-16-05-36-027/code"
+ )
+
+ s3_upload_path, env_vars = SageMakerTeiServing()._upload_tei_artifacts(
+ model_path=MODEL_PATH,
+ sagemaker_session=mock_session,
+ s3_model_data_url=S3_URI,
+ image=TEI_IMAGE,
+ )
+
+ mock_uploader.upload.assert_called_once()
+ self.assertEqual(
+ s3_upload_path,
+ {
+ "S3DataSource": {
+ "CompressionType": "None",
+ "S3DataType": "S3Prefix",
+ "S3Uri": "s3://sagemaker-us-west-2-123456789123/tei-2024-05-20-16-05-36-027/code/",
+ }
+ },
+ )
+ self.assertIsNotNone(env_vars)
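Note the trailing slash asserted on `S3Uri` above: the uploader mock returns `.../code`, so `_upload_tei_artifacts` evidently normalizes the prefix to end with `/` before building the `S3DataSource`. A sketch of that normalization (`build_s3_data_source` is a hypothetical helper name, not SDK code):

```python
def build_s3_data_source(code_upload_path):
    """Wrap an uploaded S3 prefix in an uncompressed S3DataSource spec,
    guaranteeing the trailing slash that S3Prefix data sources expect."""
    if not code_upload_path.endswith("/"):
        code_upload_path += "/"
    return {
        "S3DataSource": {
            "CompressionType": "None",
            "S3DataType": "S3Prefix",
            "S3Uri": code_upload_path,
        }
    }


assert (
    build_s3_data_source("s3://bucket/code")["S3DataSource"]["S3Uri"]
    == "s3://bucket/code/"
)
```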
diff --git a/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_prepare.py b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_prepare.py
new file mode 100644
index 0000000000..9915b19649
--- /dev/null
+++ b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_prepare.py
@@ -0,0 +1,116 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest import TestCase
+from unittest.mock import Mock, patch, mock_open
+import pytest
+
+from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving
+
+MODEL_PATH = "/path/to/your/model/dir"
+SHARED_LIBS = ["/path/to/shared/libs"]
+DEPENDENCIES = {"dependencies": "requirements.txt"}
+INFERENCE_SPEC = Mock()
+IMAGE_URI = "mock_image_uri"
+TF_1P_IMAGE_URI = "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.7-1"
+INFERENCE_SPEC.prepare = Mock(return_value=None)
+
+SECRET_KEY = "secret-key"
+
+mock_session = Mock()
+
+
+class PrepareForTensorflowServingTests(TestCase):
+ def setUp(self):
+ INFERENCE_SPEC.reset_mock()
+
+ @patch("builtins.open", new_callable=mock_open, read_data=b"{}")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents")
+ @patch(
+ "sagemaker.serve.model_server.tensorflow_serving.prepare."
+ "_get_saved_model_path_for_tensorflow_and_keras_flavor"
+ )
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._MetaData")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.shutil")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.Path")
+ def test_prepare_happy(
+ self,
+ mock_path,
+ mock_shutil,
+ mock_capture_dependencies,
+ mock_generate_secret_key,
+ mock_compute_hash,
+ mock_metadata,
+ mock_get_saved_model_path,
+ mock_move_contents,
+ mock_open,
+ ):
+
+ mock_path_instance = mock_path.return_value
+ mock_path_instance.exists.return_value = True
+ mock_path_instance.joinpath.return_value = Mock()
+ mock_get_saved_model_path.return_value = MODEL_PATH + "/1/"
+
+ mock_generate_secret_key.return_value = SECRET_KEY
+
+ secret_key = prepare_for_tf_serving(
+ model_path=MODEL_PATH,
+ shared_libs=SHARED_LIBS,
+ dependencies=DEPENDENCIES,
+ )
+
+ mock_path_instance.mkdir.assert_not_called()
+ self.assertEqual(secret_key, SECRET_KEY)
+
+ @patch("builtins.open", new_callable=mock_open, read_data=b"{}")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents")
+ @patch(
+ "sagemaker.serve.model_server.tensorflow_serving.prepare."
+ "_get_saved_model_path_for_tensorflow_and_keras_flavor"
+ )
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._MetaData")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.shutil")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.Path")
+ def test_prepare_saved_model_not_found(
+ self,
+ mock_path,
+ mock_shutil,
+ mock_capture_dependencies,
+ mock_generate_secret_key,
+ mock_compute_hash,
+ mock_metadata,
+ mock_get_saved_model_path,
+ mock_move_contents,
+ mock_open,
+ ):
+
+ mock_path_instance = mock_path.return_value
+ mock_path_instance.exists.return_value = True
+ mock_path_instance.joinpath.return_value = Mock()
+ mock_get_saved_model_path.return_value = None
+
+ with pytest.raises(
+ ValueError, match="SavedModel is not found for Tensorflow or Keras flavor."
+ ):
+ prepare_for_tf_serving(
+ model_path=MODEL_PATH,
+ shared_libs=SHARED_LIBS,
+ dependencies=DEPENDENCIES,
+ )
diff --git a/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py
new file mode 100644
index 0000000000..3d3bac0935
--- /dev/null
+++ b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py
@@ -0,0 +1,100 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from pathlib import PosixPath
+import platform
+from unittest import TestCase
+from unittest.mock import Mock, patch, ANY
+
+import numpy as np
+
+from sagemaker.serve.model_server.tensorflow_serving.server import (
+ LocalTensorflowServing,
+ SageMakerTensorflowServing,
+)
+
+CPU_TF_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:2.14.1-cpu"
+MODEL_PATH = "model_path"
+MODEL_REPO = f"{MODEL_PATH}/1"
+ENV_VAR = {"KEY": "VALUE"}
+_SHM_SIZE = "2G"
+PAYLOAD = np.random.rand(3, 4).astype(dtype=np.float32)
+S3_URI = "s3://mock_model_data_uri"
+DTYPE = "TYPE_FP32"
+SECRET_KEY = "secret_key"
+
+INFER_RESPONSE = {"outputs": [{"name": "output_name"}]}
+
+
+class TensorflowServingServerTests(TestCase):
+ def test_start_invoke_destroy_local_tensorflow_serving_server(self):
+ mock_container = Mock()
+ mock_docker_client = Mock()
+ mock_docker_client.containers.run.return_value = mock_container
+
+ local_tensorflow_server = LocalTensorflowServing()
+ mock_schema_builder = Mock()
+ mock_schema_builder.input_serializer.serialize.return_value = PAYLOAD
+ local_tensorflow_server.schema_builder = mock_schema_builder
+
+ local_tensorflow_server._start_tensorflow_serving(
+ client=mock_docker_client,
+ model_path=MODEL_PATH,
+ secret_key=SECRET_KEY,
+ env_vars=ENV_VAR,
+ image=CPU_TF_IMAGE,
+ )
+
+ mock_docker_client.containers.run.assert_called_once_with(
+ CPU_TF_IMAGE,
+ "serve",
+ detach=True,
+ auto_remove=True,
+ network_mode="host",
+ volumes={PosixPath("model_path"): {"bind": "/opt/ml/model", "mode": "rw"}},
+ environment={
+ "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+ "SAGEMAKER_PROGRAM": "inference.py",
+ "SAGEMAKER_SERVE_SECRET_KEY": "secret_key",
+ "LOCAL_PYTHON": platform.python_version(),
+ "KEY": "VALUE",
+ },
+ )
+
+ @patch("sagemaker.serve.model_server.tensorflow_serving.server.platform")
+ @patch("sagemaker.serve.model_server.tensorflow_serving.server.upload")
+ def test_upload_artifacts_sagemaker_tensorflow_serving_server(self, mock_upload, mock_platform):
+ mock_session = Mock()
+ mock_platform.python_version.return_value = "3.8"
+ mock_upload.side_effect = lambda session, repo, bucket, prefix: (
+ S3_URI
+ if session == mock_session and repo == MODEL_PATH and bucket == "mock_model_data_uri"
+ else None
+ )
+
+ (
+ s3_upload_path,
+ env_vars,
+ ) = SageMakerTensorflowServing()._upload_tensorflow_serving_artifacts(
+ model_path=MODEL_PATH,
+ sagemaker_session=mock_session,
+ secret_key=SECRET_KEY,
+ s3_model_data_url=S3_URI,
+ image=CPU_TF_IMAGE,
+ )
+
+ mock_upload.assert_called_once_with(mock_session, MODEL_PATH, "mock_model_data_uri", ANY)
+ self.assertEqual(s3_upload_path, S3_URI)
+ self.assertEqual(env_vars.get("SAGEMAKER_SERVE_SECRET_KEY"), SECRET_KEY)
+ self.assertEqual(env_vars.get("LOCAL_PYTHON"), "3.8")
diff --git a/tests/unit/sagemaker/serve/utils/test_lineage_utils.py b/tests/unit/sagemaker/serve/utils/test_lineage_utils.py
new file mode 100644
index 0000000000..25e4fe246e
--- /dev/null
+++ b/tests/unit/sagemaker/serve/utils/test_lineage_utils.py
@@ -0,0 +1,374 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest.mock import call
+
+import pytest
+from botocore.exceptions import ClientError
+from mock import Mock, patch
+from sagemaker import Session
+from sagemaker.lineage.artifact import ArtifactSummary, Artifact
+from sagemaker.lineage.query import LineageSourceEnum
+
+from sagemaker.serve.utils.lineage_constants import (
+ MLFLOW_RUN_ID,
+ MLFLOW_MODEL_PACKAGE_PATH,
+ MLFLOW_S3_PATH,
+ MLFLOW_LOCAL_PATH,
+ LINEAGE_POLLER_MAX_TIMEOUT_SECS,
+ MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+ CONTRIBUTED_TO,
+ MLFLOW_REGISTRY_PATH,
+)
+from sagemaker.serve.utils.lineage_utils import (
+ _load_artifact_by_source_uri,
+ _poll_lineage_artifact,
+ _get_mlflow_model_path_type,
+ _create_mlflow_model_path_lineage_artifact,
+ _add_association_between_artifacts,
+ _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact,
+ _maintain_lineage_tracking_for_mlflow_model,
+)
+
+
+@patch("sagemaker.lineage.artifact.Artifact.list")
+def test_load_artifact_by_source_uri(mock_artifact_list):
+ source_uri = "s3://mybucket/mymodel"
+ sagemaker_session = Mock(spec=Session)
+
+ mock_artifact_1 = Mock(spec=ArtifactSummary)
+ mock_artifact_1.artifact_type = LineageSourceEnum.MODEL_DATA.value
+ mock_artifact_2 = Mock(spec=ArtifactSummary)
+ mock_artifact_2.artifact_type = LineageSourceEnum.IMAGE.value
+ mock_artifacts = [mock_artifact_1, mock_artifact_2]
+ mock_artifact_list.return_value = mock_artifacts
+
+ result = _load_artifact_by_source_uri(
+ source_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session
+ )
+
+ mock_artifact_list.assert_called_once_with(
+ source_uri=source_uri, sagemaker_session=sagemaker_session
+ )
+ assert result == mock_artifact_1
+
+
+@patch("sagemaker.lineage.artifact.Artifact.list")
+def test_load_artifact_by_source_uri_no_match(mock_artifact_list):
+ source_uri = "s3://mybucket/mymodel"
+ sagemaker_session = Mock(spec=Session)
+
+ mock_artifact_1 = Mock(spec=ArtifactSummary)
+ mock_artifact_1.artifact_type = LineageSourceEnum.IMAGE.value
+ mock_artifact_2 = Mock(spec=ArtifactSummary)
+ mock_artifact_2.artifact_type = LineageSourceEnum.IMAGE.value
+ mock_artifacts = [mock_artifact_1, mock_artifact_2]
+ mock_artifact_list.return_value = mock_artifacts
+
+ result = _load_artifact_by_source_uri(
+ source_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session
+ )
+
+ mock_artifact_list.assert_called_once_with(
+ source_uri=source_uri, sagemaker_session=sagemaker_session
+ )
+ assert result is None
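The found/no-match pair above fixes the behavior of `_load_artifact_by_source_uri`: list all artifacts registered for a source URI and return the first whose `artifact_type` matches, else `None`. A standalone sketch of that selection logic (duck-typed here instead of calling the real `Artifact.list`):

```python
from collections import namedtuple

Summary = namedtuple("Summary", "artifact_type")


def select_artifact_by_type(artifact_summaries, artifact_type):
    """Return the first summary whose artifact_type matches, else None."""
    for summary in artifact_summaries:
        if summary.artifact_type == artifact_type:
            return summary
    return None


found = select_artifact_by_type([Summary("Image"), Summary("ModelData")], "ModelData")
assert found.artifact_type == "ModelData"
assert select_artifact_by_type([Summary("Image")], "ModelData") is None
```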
+
+
+@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri")
+def test_poll_lineage_artifact_found(mock_load_artifact):
+ s3_uri = "s3://mybucket/mymodel"
+ sagemaker_session = Mock(spec=Session)
+ mock_artifact = Mock(spec=ArtifactSummary)
+
+ with patch("time.time") as mock_time:
+ mock_time.return_value = 0
+
+ mock_load_artifact.return_value = mock_artifact
+
+ result = _poll_lineage_artifact(
+ s3_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session
+ )
+
+ assert result == mock_artifact
+ mock_load_artifact.assert_has_calls(
+ [
+ call(s3_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session),
+ ]
+ )
+
+
+@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri")
+def test_poll_lineage_artifact_not_found(mock_load_artifact):
+ s3_uri = "s3://mybucket/mymodel"
+ artifact_type = LineageSourceEnum.MODEL_DATA.value
+ sagemaker_session = Mock(spec=Session)
+
+ with patch("time.time") as mock_time:
+ mock_time_values = [0.0, 1.0, LINEAGE_POLLER_MAX_TIMEOUT_SECS + 1.0]
+ mock_time.side_effect = mock_time_values
+
+ with patch("time.sleep"):
+ mock_load_artifact.side_effect = [None, None, None]
+
+ result = _poll_lineage_artifact(s3_uri, artifact_type, sagemaker_session)
+
+ assert result is None
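These two tests describe `_poll_lineage_artifact` as a bounded poll: retry the lookup until an artifact appears or `LINEAGE_POLLER_MAX_TIMEOUT_SECS` elapses, returning `None` on timeout. A generic sketch of such a loop (interval and timeout handling are illustrative, not the utility's exact body):

```python
import time


def poll_until(lookup, timeout_secs, interval_secs=2):
    """Call lookup() until it returns a non-None value or the deadline passes."""
    deadline = time.time() + timeout_secs
    while time.time() <= deadline:
        result = lookup()
        if result is not None:
            return result
        time.sleep(interval_secs)
    return None
```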
+
+
+@pytest.mark.parametrize(
+ "mlflow_model_path, expected_output",
+ [
+ ("runs:/abc123", MLFLOW_RUN_ID),
+ ("models:/my-model/1", MLFLOW_REGISTRY_PATH),
+ (
+ "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model-package",
+ MLFLOW_MODEL_PACKAGE_PATH,
+ ),
+ ("s3://my-bucket/path/to/model", MLFLOW_S3_PATH),
+ ],
+)
+def test_get_mlflow_model_path_type_valid(mlflow_model_path, expected_output):
+ result = _get_mlflow_model_path_type(mlflow_model_path)
+ assert result == expected_output
+
+
+@patch("os.path.exists")
+def test_get_mlflow_model_path_type_valid_local_path(mock_path_exists):
+ valid_path = "/path/to/mlflow_model"
+ mock_path_exists.side_effect = lambda path: path == valid_path
+ result = _get_mlflow_model_path_type(valid_path)
+ assert result == MLFLOW_LOCAL_PATH
+
+
+def test_get_mlflow_model_path_type_invalid():
+ invalid_path = "invalid_path"
+ with pytest.raises(ValueError, match=f"Invalid MLflow model path: {invalid_path}"):
+ _get_mlflow_model_path_type(invalid_path)
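The parametrized cases above enumerate the recognized MLflow path shapes: run IDs, registry paths, model-package ARNs, S3 URIs, and existing local paths, with anything else rejected. A sketch of a classifier consistent with those cases (the string constants stand in for the imported `lineage_constants` values, whose real contents aren't shown here):

```python
import os
import re

MLFLOW_RUN_ID = "mlflow_run_id"
MLFLOW_REGISTRY_PATH = "mlflow_registry_path"
MLFLOW_MODEL_PACKAGE_PATH = "mlflow_model_package_path"
MLFLOW_S3_PATH = "mlflow_s3_path"
MLFLOW_LOCAL_PATH = "mlflow_local_path"

_MODEL_PACKAGE_ARN = re.compile(r"^arn:aws:sagemaker:.+:model-package/.+")


def classify_mlflow_model_path(path):
    """Map an MLflow model path string onto its input-type label."""
    if path.startswith("runs:/"):
        return MLFLOW_RUN_ID
    if path.startswith("models:/"):
        return MLFLOW_REGISTRY_PATH
    if _MODEL_PACKAGE_ARN.match(path):
        return MLFLOW_MODEL_PACKAGE_PATH
    if path.startswith("s3://"):
        return MLFLOW_S3_PATH
    if os.path.exists(path):
        return MLFLOW_LOCAL_PATH
    raise ValueError(f"Invalid MLflow model path: {path}")
```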
+
+
+@patch("sagemaker.serve.utils.lineage_utils._get_mlflow_model_path_type")
+@patch("sagemaker.lineage.artifact.Artifact.create")
+def test_create_mlflow_model_path_lineage_artifact_success(
+ mock_artifact_create, mock_get_mlflow_path_type
+):
+ mlflow_model_path = "runs:/Ab12Cd34"
+ sagemaker_session = Mock(spec=Session)
+ mock_artifact = Mock(spec=Artifact)
+ mock_get_mlflow_path_type.return_value = "mlflow_run_id"
+ mock_artifact_create.return_value = mock_artifact
+
+ result = _create_mlflow_model_path_lineage_artifact(mlflow_model_path, sagemaker_session)
+
+ assert result == mock_artifact
+ mock_get_mlflow_path_type.assert_called_once_with(mlflow_model_path)
+ mock_artifact_create.assert_called_once_with(
+ source_uri=mlflow_model_path,
+ artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
+ artifact_name="mlflow_run_id",
+ properties={"model_builder_input_model_data_type": "mlflow_run_id"},
+ sagemaker_session=sagemaker_session,
+ )
+
+
+@patch("sagemaker.serve.utils.lineage_utils._get_mlflow_model_path_type")
+@patch("sagemaker.lineage.artifact.Artifact.create")
+def test_create_mlflow_model_path_lineage_artifact_validation_exception(
+ mock_artifact_create, mock_get_mlflow_path_type
+):
+ mlflow_model_path = "runs:/Ab12Cd34"
+ sagemaker_session = Mock(spec=Session)
+ mock_get_mlflow_path_type.return_value = "mlflow_run_id"
+ mock_artifact_create.side_effect = ClientError(
+ error_response={"Error": {"Code": "ValidationException"}}, operation_name="CreateArtifact"
+ )
+
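+ # The helper should swallow ValidationException from CreateArtifact and return None instead of raising.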
+ result = _create_mlflow_model_path_lineage_artifact(mlflow_model_path, sagemaker_session)
+
+ assert result is None
+
+
+@patch("sagemaker.serve.utils.lineage_utils._get_mlflow_model_path_type")
+@patch("sagemaker.lineage.artifact.Artifact.create")
+def test_create_mlflow_model_path_lineage_artifact_other_exception(
+ mock_artifact_create, mock_get_mlflow_path_type
+):
+ mlflow_model_path = "runs:/Ab12Cd34"
+ sagemaker_session = Mock(spec=Session)
+ mock_get_mlflow_path_type.return_value = "mlflow_run_id"
+ mock_artifact_create.side_effect = ClientError(
+ error_response={"Error": {"Code": "SomeOtherException"}}, operation_name="CreateArtifact"
+ )
+
+ with pytest.raises(ClientError):
+ _create_mlflow_model_path_lineage_artifact(mlflow_model_path, sagemaker_session)
+
+
+@patch("sagemaker.serve.utils.lineage_utils._create_mlflow_model_path_lineage_artifact")
+@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri")
+def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_existing(
+ mock_load_artifact, mock_create_artifact
+):
+ mlflow_model_path = "runs:/Ab12Cd34"
+ sagemaker_session = Mock(spec=Session)
+ mock_artifact_summary = Mock(spec=ArtifactSummary)
+ mock_load_artifact.return_value = mock_artifact_summary
+
+ result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
+ mlflow_model_path, sagemaker_session
+ )
+
+ assert result == mock_artifact_summary
+ mock_load_artifact.assert_called_once_with(
+ mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session
+ )
+ mock_create_artifact.assert_not_called()
+
+
+@patch("sagemaker.serve.utils.lineage_utils._create_mlflow_model_path_lineage_artifact")
+@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri")
+def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_create(
+ mock_load_artifact, mock_create_artifact
+):
+ mlflow_model_path = "runs:/Ab12Cd34"
+ sagemaker_session = Mock(spec=Session)
+ mock_artifact = Mock(spec=Artifact)
+ mock_load_artifact.return_value = None
+ mock_create_artifact.return_value = mock_artifact
+
+ result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
+ mlflow_model_path, sagemaker_session
+ )
+
+ assert result == mock_artifact
+ mock_load_artifact.assert_called_once_with(
+ mlflow_model_path, MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, sagemaker_session
+ )
+ mock_create_artifact.assert_called_once_with(mlflow_model_path, sagemaker_session)
+
+
+@patch("sagemaker.lineage.association.Association.create")
+def test_add_association_between_artifacts_success(mock_association_create):
+ mlflow_model_path_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/123"
+ autogenerated_model_data_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/456"
+ sagemaker_session = Mock(spec=Session)
+
+ _add_association_between_artifacts(
+ mlflow_model_path_artifact_arn,
+ autogenerated_model_data_artifact_arn,
+ sagemaker_session,
+ )
+
+ mock_association_create.assert_called_once_with(
+ source_arn=mlflow_model_path_artifact_arn,
+ destination_arn=autogenerated_model_data_artifact_arn,
+ association_type=CONTRIBUTED_TO,
+ sagemaker_session=sagemaker_session,
+ )
+
+
+@patch("sagemaker.lineage.association.Association.create")
+def test_add_association_between_artifacts_validation_exception(mock_association_create):
+ mlflow_model_path_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/123"
+ autogenerated_model_data_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/456"
+ sagemaker_session = Mock(spec=Session)
+ mock_association_create.side_effect = ClientError(
+ error_response={"Error": {"Code": "ValidationException"}},
+ operation_name="CreateAssociation",
+ )
+
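+ # ValidationException from CreateAssociation should be swallowed; the call below must complete without raising.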
+ _add_association_between_artifacts(
+ mlflow_model_path_artifact_arn,
+ autogenerated_model_data_artifact_arn,
+ sagemaker_session,
+ )
+
+
+@patch("sagemaker.lineage.association.Association.create")
+def test_add_association_between_artifacts_other_exception(mock_association_create):
+ mlflow_model_path_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/123"
+ autogenerated_model_data_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/456"
+ sagemaker_session = Mock(spec=Session)
+ mock_association_create.side_effect = ClientError(
+ error_response={"Error": {"Code": "SomeOtherException"}}, operation_name="CreateAssociation"
+ )
+
+ with pytest.raises(ClientError):
+ _add_association_between_artifacts(
+ mlflow_model_path_artifact_arn,
+ autogenerated_model_data_artifact_arn,
+ sagemaker_session,
+ )
+
+
+@patch("sagemaker.serve.utils.lineage_utils._poll_lineage_artifact")
+@patch(
+ "sagemaker.serve.utils.lineage_utils._retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact"
+)
+@patch("sagemaker.serve.utils.lineage_utils._add_association_between_artifacts")
+def test_maintain_lineage_tracking_for_mlflow_model_success(
+ mock_add_association, mock_retrieve_create_artifact, mock_poll_artifact
+):
+ mlflow_model_path = "runs:/Ab12Cd34"
+ s3_upload_path = "s3://mybucket/path/to/model"
+ sagemaker_session = Mock(spec=Session)
+ mock_model_data_artifact = Mock(spec=ArtifactSummary)
+ mock_mlflow_model_artifact = Mock(spec=Artifact)
+ mock_poll_artifact.return_value = mock_model_data_artifact
+ mock_retrieve_create_artifact.return_value = mock_mlflow_model_artifact
+
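+ # Full flow: poll for the auto-generated model-data artifact, resolve or create the MLflow path artifact, then associate the two.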
+ _maintain_lineage_tracking_for_mlflow_model(
+ mlflow_model_path, s3_upload_path, sagemaker_session
+ )
+
+ mock_poll_artifact.assert_called_once_with(
+ s3_uri=s3_upload_path,
+ artifact_type=LineageSourceEnum.MODEL_DATA.value,
+ sagemaker_session=sagemaker_session,
+ )
+ mock_retrieve_create_artifact.assert_called_once_with(
+ mlflow_model_path=mlflow_model_path, sagemaker_session=sagemaker_session
+ )
+ mock_add_association.assert_called_once_with(
+ mlflow_model_path_artifact_arn=mock_mlflow_model_artifact.artifact_arn,
+ autogenerated_model_data_artifact_arn=mock_model_data_artifact.artifact_arn,
+ sagemaker_session=sagemaker_session,
+ )
+
+
+@patch("sagemaker.serve.utils.lineage_utils._poll_lineage_artifact")
+@patch(
+ "sagemaker.serve.utils.lineage_utils._retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact"
+)
+@patch("sagemaker.serve.utils.lineage_utils._add_association_between_artifacts")
+def test_maintain_lineage_tracking_for_mlflow_model_no_model_data_artifact(
+ mock_add_association, mock_retrieve_create_artifact, mock_poll_artifact
+):
+ mlflow_model_path = "runs:/Ab12Cd34"
+ s3_upload_path = "s3://mybucket/path/to/model"
+ sagemaker_session = Mock(spec=Session)
+ mock_poll_artifact.return_value = None
+ mock_retrieve_create_artifact.return_value = None
+
+ _maintain_lineage_tracking_for_mlflow_model(
+ mlflow_model_path, s3_upload_path, sagemaker_session
+ )
+
+ mock_poll_artifact.assert_called_once_with(
+ s3_uri=s3_upload_path,
+ artifact_type=LineageSourceEnum.MODEL_DATA.value,
+ sagemaker_session=sagemaker_session,
+ )
+ mock_retrieve_create_artifact.assert_not_called()
+ mock_add_association.assert_not_called()
diff --git a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py
index e8273dd9a1..33af575e8f 100644
--- a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py
+++ b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py
@@ -14,6 +14,7 @@
import unittest
from unittest.mock import Mock, patch
from sagemaker.serve import Mode, ModelServer
+from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH
from sagemaker.serve.utils.telemetry_logger import (
_send_telemetry,
_capture_telemetry,
@@ -32,9 +33,13 @@
"763104351884.dkr.ecr.us-east-1.amazonaws.com/"
"huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04"
)
+MOCK_PYTORCH_CONTAINER = (
+ "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310"
+)
MOCK_HUGGINGFACE_ID = "meta-llama/Llama-2-7b-hf"
MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex")
MOCK_ENDPOINT_ARN = "arn:aws:sagemaker:us-west-2:123456789012:endpoint/test"
+MOCK_MODEL_METADATA_FOR_MLFLOW = {MLFLOW_MODEL_PATH: "s3://some_path"}
class ModelBuilderMock:
@@ -239,3 +244,34 @@ def test_construct_url_with_failure_reason_and_extra_info(self):
f"&x-extra={mock_extra_info}"
)
self.assertEquals(ret_url, expected_base_url)
+
+ @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
+ def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry):
+ mock_model_builder = ModelBuilderMock()
+ mock_model_builder.serve_settings.telemetry_opt_out = False
+ mock_model_builder.image_uri = MOCK_PYTORCH_CONTAINER
+ mock_model_builder._is_mlflow_model = True
+ mock_model_builder.model_metadata = MOCK_MODEL_METADATA_FOR_MLFLOW
+ mock_model_builder._is_custom_image_uri = False
+ mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT
+ mock_model_builder.model_server = ModelServer.TORCHSERVE
+ mock_model_builder.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN
+
+ mock_model_builder.mock_deploy()
+
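+ # Latency is runtime-dependent, so read it back from the captured call rather than hard-coding it.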
+ args = mock_send_telemetry.call_args.args
+ latency = str(args[5]).split("latency=")[1]
+ expected_extra_str = (
+ f"{MOCK_FUNC_NAME}"
+ "&x-modelServer=1"
+ "&x-imageTag=pytorch-inference:2.0.1-cpu-py310"
+ f"&x-sdkVersion={SDK_VERSION}"
+ f"&x-defaultImageUsage={ImageUriOption.DEFAULT_IMAGE.value}"
+ f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
+ f"&x-mlflowModelPathType=2"
+ f"&x-latency={latency}"
+ )
+
+ mock_send_telemetry.assert_called_once_with(
+ "1", 3, MOCK_SESSION, None, None, expected_extra_str
+ )
diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py
index 382c48fde6..fd45601801 100644
--- a/tests/unit/test_estimator.py
+++ b/tests/unit/test_estimator.py
@@ -2089,6 +2089,41 @@ def test_framework_disable_remote_debug(sagemaker_session):
assert len(args) == 2
+def test_framework_with_session_chaining_config(sagemaker_session):
+ f = DummyFramework(
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ sagemaker_session=sagemaker_session,
+ instance_groups=[
+ InstanceGroup("group1", "ml.c4.xlarge", 1),
+ InstanceGroup("group2", "ml.m4.xlarge", 2),
+ ],
+ enable_session_tag_chaining=True,
+ )
+ f.fit("s3://mydata")
+ sagemaker_session.train.assert_called_once()
+ _, args = sagemaker_session.train.call_args
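+ # fit() should forward the flag to session.train as session_chaining_config and expose it via the getter.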
+ assert args["session_chaining_config"]["EnableSessionTagChaining"]
+ assert f.get_session_chaining_config()["EnableSessionTagChaining"]
+
+
+def test_framework_without_session_chaining_config(sagemaker_session):
+ f = DummyFramework(
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ sagemaker_session=sagemaker_session,
+ instance_groups=[
+ InstanceGroup("group1", "ml.c4.xlarge", 1),
+ InstanceGroup("group2", "ml.m4.xlarge", 2),
+ ],
+ )
+ f.fit("s3://mydata")
+ sagemaker_session.train.assert_called_once()
+ _, args = sagemaker_session.train.call_args
+ assert args.get("SessionTagChaining") is None
+ assert f.get_remote_debug_config() is None
+
+
@patch("time.strftime", return_value=TIMESTAMP)
def test_custom_code_bucket(time, sagemaker_session):
code_bucket = "codebucket"
diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py
index e955d68227..97d4e6ec2a 100644
--- a/tests/unit/test_fw_utils.py
+++ b/tests/unit/test_fw_utils.py
@@ -854,17 +854,14 @@ def test_validate_smdataparallel_args_raises():
# Cases {PT|TF2}
# 1. None instance type
- # 2. incorrect instance type
- # 3. incorrect python version
- # 4. incorrect framework version
+ # 2. incorrect python version
+ # 3. incorrect framework version
bad_args = [
(None, "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
- ("ml.p3.2xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
("ml.p3dn.24xlarge", "tensorflow", "2.3.1", "py2", smdataparallel_enabled),
("ml.p3.16xlarge", "tensorflow", "1.3.1", "py3", smdataparallel_enabled),
(None, "pytorch", "1.6.0", "py3", smdataparallel_enabled),
- ("ml.p3.2xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
("ml.p3dn.24xlarge", "pytorch", "1.6.0", "py2", smdataparallel_enabled),
("ml.p3.16xlarge", "pytorch", "1.5.0", "py3", smdataparallel_enabled),
]
@@ -966,74 +963,6 @@ def test_validate_smdataparallel_args_not_raises():
)
-def test_validate_pytorchddp_not_raises():
- # Case 1: Framework is not PyTorch
- fw_utils.validate_pytorch_distribution(
- distribution=None,
- framework_name="tensorflow",
- framework_version="2.9.1",
- py_version="py3",
- image_uri="custom-container",
- )
- # Case 2: Framework is PyTorch, but distribution is not PyTorchDDP
- pytorchddp_disabled = {"pytorchddp": {"enabled": False}}
- fw_utils.validate_pytorch_distribution(
- distribution=pytorchddp_disabled,
- framework_name="pytorch",
- framework_version="1.10",
- py_version="py3",
- image_uri="custom-container",
- )
- # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
- pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
- pytorchddp_supported_fw_versions = [
- "1.10",
- "1.10.0",
- "1.10.2",
- "1.11",
- "1.11.0",
- "1.12",
- "1.12.0",
- "1.12.1",
- "1.13.1",
- "2.0.0",
- "2.0.1",
- "2.1.0",
- "2.2.0",
- ]
- for framework_version in pytorchddp_supported_fw_versions:
- fw_utils.validate_pytorch_distribution(
- distribution=pytorchddp_enabled,
- framework_name="pytorch",
- framework_version=framework_version,
- py_version="py3",
- image_uri="custom-container",
- )
-
-
-def test_validate_pytorchddp_raises():
- pytorchddp_enabled = {"pytorchddp": {"enabled": True}}
- # Case 1: Unsupported framework version
- with pytest.raises(ValueError):
- fw_utils.validate_pytorch_distribution(
- distribution=pytorchddp_enabled,
- framework_name="pytorch",
- framework_version="1.8",
- py_version="py3",
- image_uri=None,
- )
-
- # Case 2: Unsupported Py version
- with pytest.raises(ValueError):
- fw_utils.validate_pytorch_distribution(
- distribution=pytorchddp_enabled,
- framework_name="pytorch",
- framework_version="1.10",
- py_version="py2",
- image_uri=None,
- )
-
-
def test_validate_torch_distributed_not_raises():
# Case 1: Framework is PyTorch, but torch_distributed is not enabled
torch_distributed_disabled = {"torch_distributed": {"enabled": False}}
diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py
index 5ada026ef8..618d0d7ea8 100644
--- a/tests/unit/test_pytorch.py
+++ b/tests/unit/test_pytorch.py
@@ -801,14 +801,15 @@ def test_pytorch_ddp_distribution_configuration(
distribution=pytorch.distribution
)
expected_torch_ddp = {
- "sagemaker_pytorch_ddp_enabled": True,
+ "sagemaker_distributed_dataparallel_enabled": True,
+ "sagemaker_distributed_dataparallel_custom_mpi_options": "",
"sagemaker_instance_type": test_instance_type,
}
assert actual_pytorch_ddp == expected_torch_ddp
def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session):
- unsupported_framework_version = "1.9.1"
+ unsupported_framework_version = "1.5.0"
unsupported_py_version = "py2"
with pytest.raises(ValueError) as error:
_pytorch_estimator(
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 19f9d0ae3d..f7dede1ce9 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -43,8 +43,6 @@
from sagemaker.utils import update_list_of_dicts_with_values_from_config
from sagemaker.user_agent import (
SDK_PREFIX,
- STUDIO_PREFIX,
- NOTEBOOK_PREFIX,
)
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from tests.unit import (
@@ -87,15 +85,20 @@
limits={},
)
+SDK_DEFAULT_SUFFIX = f"lib/{SDK_PREFIX}#2.218.0"
+NOTEBOOK_SUFFIX = f"{SDK_DEFAULT_SUFFIX} md/AWS-SageMaker-Notebook-Instance#instance_type"
+STUDIO_SUFFIX = f"{SDK_DEFAULT_SUFFIX} md/AWS-SageMaker-Studio#app_type"
-@pytest.fixture()
-def boto_session():
- boto_mock = Mock(name="boto_session", region_name=REGION)
+@pytest.fixture
+def boto_session(request):
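+ # Tests choose the metadata suffix (notebook/studio) via indirect parametrization; default is no suffix.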
+ boto_user_agent = "Boto3/1.33.9 md/Botocore#1.33.9 ua/2.0 os/linux#linux-ver md/arch#x86_64 lang/python#3.10.6"
+ user_agent_suffix = getattr(request, "param", "")
+ boto_mock = Mock(name="boto_session", region_name=REGION)
client_mock = Mock()
- client_mock._client_config.user_agent = (
- "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource"
- )
+ user_agent = f"{boto_user_agent} {SDK_DEFAULT_SUFFIX} {user_agent_suffix}"
+ with patch("sagemaker.user_agent.get_user_agent_extra_suffix", return_value=user_agent_suffix):
+ client_mock._client_config.user_agent = user_agent
boto_mock.client.return_value = client_mock
return boto_mock
@@ -887,65 +890,42 @@ def test_delete_model(boto_session):
boto_session.client().delete_model.assert_called_with(ModelName=model_name)
+@pytest.mark.parametrize("boto_session", [""], indirect=True)
def test_user_agent_injected(boto_session):
- assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent
-
sess = Session(boto_session)
-
+ expected_user_agent_suffix = "lib/AWS-SageMaker-Python-SDK#2.218.0"
for client in [
sess.sagemaker_client,
sess.sagemaker_runtime_client,
sess.sagemaker_metrics_client,
]:
- assert SDK_PREFIX in client._client_config.user_agent
- assert NOTEBOOK_PREFIX not in client._client_config.user_agent
- assert STUDIO_PREFIX not in client._client_config.user_agent
+ assert expected_user_agent_suffix in client._client_config.user_agent
-@patch("sagemaker.user_agent.process_notebook_metadata_file", return_value="ml.t3.medium")
-def test_user_agent_injected_with_nbi(
- mock_process_notebook_metadata_file,
- boto_session,
-):
- assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent
-
- sess = Session(
- boto_session=boto_session,
+@pytest.mark.parametrize("boto_session", [f"{NOTEBOOK_SUFFIX}"], indirect=True)
+def test_user_agent_with_notebook_instance_type(boto_session):
+ sess = Session(boto_session)
+ expected_user_agent_suffix = (
+ "lib/AWS-SageMaker-Python-SDK#2.218.0 md/AWS-SageMaker-Notebook-Instance#instance_type"
)
-
for client in [
sess.sagemaker_client,
sess.sagemaker_runtime_client,
sess.sagemaker_metrics_client,
]:
- mock_process_notebook_metadata_file.assert_called()
-
- assert SDK_PREFIX in client._client_config.user_agent
- assert NOTEBOOK_PREFIX in client._client_config.user_agent
- assert STUDIO_PREFIX not in client._client_config.user_agent
+ assert expected_user_agent_suffix in client._client_config.user_agent
-@patch("sagemaker.user_agent.process_studio_metadata_file", return_value="dymmy-app-type")
-def test_user_agent_injected_with_studio_app_type(
- mock_process_studio_metadata_file,
- boto_session,
-):
- assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent
-
- sess = Session(
- boto_session=boto_session,
- )
-
+@pytest.mark.parametrize("boto_session", [f"{STUDIO_SUFFIX}"], indirect=True)
+def test_user_agent_with_studio_app_type(boto_session):
+ sess = Session(boto_session)
+ expected_user_agent = "lib/AWS-SageMaker-Python-SDK#2.218.0 md/AWS-SageMaker-Studio#app_type"
for client in [
sess.sagemaker_client,
sess.sagemaker_runtime_client,
sess.sagemaker_metrics_client,
]:
- mock_process_studio_metadata_file.assert_called()
-
- assert SDK_PREFIX in client._client_config.user_agent
- assert NOTEBOOK_PREFIX not in client._client_config.user_agent
- assert STUDIO_PREFIX in client._client_config.user_agent
+ assert expected_user_agent in client._client_config.user_agent
def test_training_input_all_defaults():
@@ -2197,6 +2177,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
CONTAINER_ENTRY_POINT = ["bin/bash", "test.sh"]
CONTAINER_ARGUMENTS = ["--arg1", "value1", "--arg2", "value2"]
remote_debug_config = {"EnableRemoteDebug": True}
+ session_chaining_config = {"EnableSessionTagChaining": True}
sagemaker_session.train(
image_uri=IMAGE,
@@ -2222,6 +2203,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
container_entry_point=CONTAINER_ENTRY_POINT,
container_arguments=CONTAINER_ARGUMENTS,
remote_debug_config=remote_debug_config,
+ session_chaining_config=session_chaining_config,
)
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
@@ -2245,6 +2227,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
)
assert actual_train_args["AlgorithmSpecification"]["ContainerArguments"] == CONTAINER_ARGUMENTS
assert actual_train_args["RemoteDebugConfig"]["EnableRemoteDebug"]
+ assert actual_train_args["SessionChainingConfig"]["EnableSessionTagChaining"]
def test_create_transform_job_with_sagemaker_config_injection(sagemaker_session):
diff --git a/tests/unit/test_user_agent.py b/tests/unit/test_user_agent.py
index c116fef951..fb46988e7b 100644
--- a/tests/unit/test_user_agent.py
+++ b/tests/unit/test_user_agent.py
@@ -13,20 +13,17 @@
from __future__ import absolute_import
import json
-from mock import MagicMock, patch, mock_open
+from mock import patch, mock_open
from sagemaker.user_agent import (
SDK_PREFIX,
SDK_VERSION,
- PYTHON_VERSION,
- OS_NAME_VERSION,
NOTEBOOK_PREFIX,
STUDIO_PREFIX,
process_notebook_metadata_file,
process_studio_metadata_file,
- determine_prefix,
- prepend_user_agent,
+ get_user_agent_extra_suffix,
)
@@ -60,45 +57,18 @@ def test_process_studio_metadata_file_not_exists(tmp_path):
assert process_studio_metadata_file() is None
-# Test determine_prefix function
-def test_determine_prefix_notebook_instance_type(monkeypatch):
- monkeypatch.setattr(
- "sagemaker.user_agent.process_notebook_metadata_file", lambda: "instance_type"
- )
- assert (
- determine_prefix()
- == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION} {NOTEBOOK_PREFIX}/instance_type"
- )
-
-
-def test_determine_prefix_studio_app_type(monkeypatch):
- monkeypatch.setattr(
- "sagemaker.user_agent.process_studio_metadata_file", lambda: "studio_app_type"
- )
- assert (
- determine_prefix()
- == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION} {STUDIO_PREFIX}/studio_app_type"
- )
-
-
-def test_determine_prefix_no_metadata(monkeypatch):
- monkeypatch.setattr("sagemaker.user_agent.process_notebook_metadata_file", lambda: None)
- monkeypatch.setattr("sagemaker.user_agent.process_studio_metadata_file", lambda: None)
- assert determine_prefix() == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION}"
-
-
-# Test prepend_user_agent function
-def test_prepend_user_agent_existing_user_agent(monkeypatch):
- client = MagicMock()
- client._client_config.user_agent = "existing_user_agent"
- monkeypatch.setattr("sagemaker.user_agent.determine_prefix", lambda _: "prefix")
- prepend_user_agent(client)
- assert client._client_config.user_agent == "prefix existing_user_agent"
-
-
-def test_prepend_user_agent_no_user_agent(monkeypatch):
- client = MagicMock()
- client._client_config.user_agent = None
- monkeypatch.setattr("sagemaker.user_agent.determine_prefix", lambda _: "prefix")
- prepend_user_agent(client)
- assert client._client_config.user_agent == "prefix"
+# Test get_user_agent_extra_suffix function
+def test_get_user_agent_extra_suffix():
+ assert get_user_agent_extra_suffix() == f"lib/{SDK_PREFIX}#{SDK_VERSION}"
+
+ with patch("sagemaker.user_agent.process_notebook_metadata_file", return_value="instance_type"):
+ assert (
+ get_user_agent_extra_suffix()
+ == f"lib/{SDK_PREFIX}#{SDK_VERSION} md/{NOTEBOOK_PREFIX}#instance_type"
+ )
+
+ with patch("sagemaker.user_agent.process_studio_metadata_file", return_value="studio_type"):
+ assert (
+ get_user_agent_extra_suffix()
+ == f"lib/{SDK_PREFIX}#{SDK_VERSION} md/{STUDIO_PREFIX}#studio_type"
+ )
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 81d8279e6d..d5214d01c3 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -30,6 +30,7 @@
from mock import call, patch, Mock, MagicMock, PropertyMock
import sagemaker
+from sagemaker.enums import RoutingStrategy
from sagemaker.experiments._run_context import _RunContext
from sagemaker.session_settings import SessionSettings
from sagemaker.utils import (
@@ -50,6 +51,9 @@
_is_bad_link,
custom_extractall_tarfile,
can_model_package_source_uri_autopopulate,
+ get_instance_rate_per_hour,
+ extract_instance_rate_per_hour,
+ _resolve_routing_config,
)
from tests.unit.sagemaker.workflow.helpers import CustomStep
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
@@ -1817,7 +1821,13 @@ def test_can_model_package_source_uri_autopopulate():
class TestDeepMergeDict(TestCase):
def test_flatten_dict_basic(self):
nested_dict = {"a": 1, "b": {"x": 2, "y": {"p": 3, "q": 4}}, "c": 5}
- flattened_dict = {"a": 1, "b.x": 2, "b.y.p": 3, "b.y.q": 4, "c": 5}
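+ # Flattened keys are now path tuples rather than dotted strings; unflatten_dict inverts the mapping.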
+ flattened_dict = {
+ ("a",): 1,
+ ("b", "x"): 2,
+ ("b", "y", "p"): 3,
+ ("b", "y", "q"): 4,
+ ("c",): 5,
+ }
self.assertDictEqual(flatten_dict(nested_dict), flattened_dict)
self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict)
@@ -1829,13 +1839,19 @@ def test_flatten_dict_empty(self):
def test_flatten_dict_no_nested(self):
nested_dict = {"a": 1, "b": 2, "c": 3}
- flattened_dict = {"a": 1, "b": 2, "c": 3}
+ flattened_dict = {("a",): 1, ("b",): 2, ("c",): 3}
self.assertDictEqual(flatten_dict(nested_dict), flattened_dict)
self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict)
def test_flatten_dict_with_various_types(self):
nested_dict = {"a": [1, 2, 3], "b": {"x": None, "y": {"p": [], "q": ""}}, "c": 9}
- flattened_dict = {"a": [1, 2, 3], "b.x": None, "b.y.p": [], "b.y.q": "", "c": 9}
+ flattened_dict = {
+ ("a",): [1, 2, 3],
+ ("b", "x"): None,
+ ("b", "y", "p"): [],
+ ("b", "y", "q"): "",
+ ("c",): 9,
+ }
self.assertDictEqual(flatten_dict(nested_dict), flattened_dict)
self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict)
@@ -1866,3 +1882,164 @@ def test_deep_override_skip_keys(self):
expected_result = {"a": 1, "b": {"x": 20, "y": 3, "z": 30}, "c": [4, 5]}
self.assertEqual(deep_override_dict(dict1, dict2, skip_keys=["c", "d"]), expected_result)
+
+
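+# get_instance_rate_per_hour is exercised with PriceList entries returned both as parsed dicts and as raw JSON strings.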
+@pytest.mark.parametrize(
+ "instance, region, amazon_sagemaker_price_result, expected",
+ [
+ (
+ "ml.t4g.nano",
+ "us-west-2",
+ {
+ "PriceList": [
+ {
+ "terms": {
+ "OnDemand": {
+ "3WK7G7WSYVS3K492.JRTCKXETXF": {
+ "priceDimensions": {
+ "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": {
+ "unit": "Hrs",
+ "endRange": "Inf",
+ "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour",
+ "appliesTo": [],
+ "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7",
+ "beginRange": "0",
+ "pricePerUnit": {"USD": "0.9000000000"},
+ }
+ }
+ }
+ }
+ },
+ }
+ ]
+ },
+ {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9"},
+ ),
+ (
+ "ml.t4g.nano",
+ "eu-central-1",
+ {
+ "PriceList": [
+ '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{'
+ '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": '
+ '"$0.0083 per'
+ "On"
+ 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": '
+ '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": '
+ '"0.0083000000"}}},'
+ '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",'
+ '"termAttributes": {}}}}}'
+ ]
+ },
+ {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"},
+ ),
+ (
+ "ml.t4g.nano",
+ "af-south-1",
+ {
+ "PriceList": [
+ '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{'
+ '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": '
+ '"$0.0083 per'
+ "On"
+ 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": '
+ '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": '
+ '"0.0083000000"}}},'
+ '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",'
+ '"termAttributes": {}}}}}'
+ ]
+ },
+ {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"},
+ ),
+ (
+ "ml.t4g.nano",
+ "ap-northeast-2",
+ {
+ "PriceList": [
+ '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{'
+ '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": '
+ '"$0.0083 per'
+ "On"
+ 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": '
+ '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": '
+ '"0.0083000000"}}},'
+ '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",'
+ '"termAttributes": {}}}}}'
+ ]
+ },
+ {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"},
+ ),
+ ],
+)
+@patch("boto3.client")
+def test_get_instance_rate_per_hour(
+ mock_client, instance, region, amazon_sagemaker_price_result, expected
+):
+ mock_client.return_value.get_products.side_effect = (
+ lambda *args, **kwargs: amazon_sagemaker_price_result
+ )
+ instance_rate = get_instance_rate_per_hour(instance_type=instance, region=region)
+
+ assert instance_rate == expected
+
+
+@pytest.mark.parametrize(
+ "price_data, expected_result",
+ [
+ (None, None),
+ (
+ {
+ "terms": {
+ "OnDemand": {
+ "3WK7G7WSYVS3K492.JRTCKXETXF": {
+ "priceDimensions": {
+ "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": {
+ "unit": "Hrs",
+ "endRange": "Inf",
+ "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour",
+ "appliesTo": [],
+ "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7",
+ "beginRange": "0",
+ "pricePerUnit": {"USD": "0.9000000000"},
+ }
+ }
+ }
+ }
+ }
+ },
+ {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9"},
+ ),
+ ],
+)
+def test_extract_instance_rate_per_hour(price_data, expected_result):
+ out = extract_instance_rate_per_hour(price_data)
+
+ assert out == expected_result
+
+
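+# RoutingStrategy enum members and plain strings should both resolve to the plain string form.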
+@pytest.mark.parametrize(
+ "routing_config, expected",
+ [
+ ({"RoutingStrategy": RoutingStrategy.RANDOM}, {"RoutingStrategy": "RANDOM"}),
+ ({"RoutingStrategy": "RANDOM"}, {"RoutingStrategy": "RANDOM"}),
+ (
+ {"RoutingStrategy": RoutingStrategy.LEAST_OUTSTANDING_REQUESTS},
+ {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
+ ),
+ (
+ {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
+ {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
+ ),
+ ({"RoutingStrategy": None}, None),
+ (None, None),
+ ],
+)
+def test_resolve_routing_config(routing_config, expected):
+ res = _resolve_routing_config(routing_config)
+
+ assert res == expected
+
+
+def test_resolve_routing_config_ex():
+ with pytest.raises(ValueError):
+ _resolve_routing_config({"RoutingStrategy": "Invalid"})
diff --git a/tox.ini b/tox.ini
index 718e968013..6e1f9ce956 100644
--- a/tox.ini
+++ b/tox.ini
@@ -81,7 +81,7 @@ passenv =
# Can be used to specify which tests to run, e.g.: tox -- -s
commands =
python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')"
- pip install 'apache-airflow==2.9.0' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.0/constraints-3.8.txt"
+ pip install 'apache-airflow==2.9.1' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.1/constraints-3.8.txt"
pip install 'torch==2.0.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html'
pip install 'torchvision==0.15.2+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html'
pip install 'dill>=0.3.8'