diff --git a/.githooks/pre-push b/.githooks/pre-push
index 995ab70108..f73fa492b3 100755
--- a/.githooks/pre-push
+++ b/.githooks/pre-push
@@ -12,5 +12,5 @@ start_time=`date +%s`
 tox -e sphinx,doc8 --parallel all
 ./ci-scripts/displaytime.sh 'sphinx,doc8' $start_time
 start_time=`date +%s`
-tox -e py38,py39,py310 --parallel all -- tests/unit
-./ci-scripts/displaytime.sh 'py38,py39,py310 unit' $start_time
+tox -e py39,py310,py311,py312 --parallel all -- tests/unit
+./ci-scripts/displaytime.sh 'py39,py310,py311,py312 unit' $start_time
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 6240c697f8..e659c40513 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -22,5 +22,6 @@ _Put an `x` in the boxes that apply. You can also fill these out after creating
 - [ ] I have added unit and/or integration tests as appropriate to ensure backward compatibility of the changes
 - [ ] I have checked that my tests are not configured for a specific region or account (if appropriate)
 - [ ] I have used [`unique_name_from_base`](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/utils.py#L77) to create resource names in integ tests (if appropriate)
+- [ ] If adding any dependency in requirements.txt files, I have spell checked and ensured they exist in PyPI

 By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
diff --git a/.github/workflows/codebuild-canaries.yml b/.github/workflows/codebuild-canaries.yml
new file mode 100644
index 0000000000..a6b5a978ef
--- /dev/null
+++ b/.github/workflows/codebuild-canaries.yml
@@ -0,0 +1,24 @@
+name: Canaries
+on:
+  schedule:
+    - cron: "0 */3 * * *"
+  workflow_dispatch:
+
+permissions:
+  id-token: write # This is required for requesting the JWT
+
+jobs:
+  tests:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Configure AWS Credentials
+        uses: aws-actions/configure-aws-credentials@v4
+        with:
+          role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }}
+          aws-region: us-west-2
+          role-duration-seconds: 10800
+      - name: Run Integ Tests
+        uses: aws-actions/aws-codebuild-run-build@v1
+        id: codebuild
+        with:
+          project-name: sagemaker-python-sdk-canaries
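A quick way to keep an eye on the canary runs this workflow kicks off is to query the CodeBuild project it targets. A minimal sketch with boto3 (the project name and region come from the workflow above; the five-build window is an arbitrary choice):

.. code:: python

    import boto3

    # Inspect the most recent builds of the canaries CodeBuild project.
    codebuild = boto3.client("codebuild", region_name="us-west-2")

    build_ids = codebuild.list_builds_for_project(
        projectName="sagemaker-python-sdk-canaries", sortOrder="DESCENDING"
    )["ids"][:5]

    for build in codebuild.batch_get_builds(ids=build_ids)["builds"]:
        # buildStatus is SUCCEEDED, FAILED, IN_PROGRESS, etc.
        print(build["id"], build["buildStatus"])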
diff --git a/.github/workflows/codebuild-ci-health.yml b/.github/workflows/codebuild-ci-health.yml
index 7ecefd310f..119b9dbe9c 100644
--- a/.github/workflows/codebuild-ci-health.yml
+++ b/.github/workflows/codebuild-ci-health.yml
@@ -26,7 +26,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["py38", "py39", "py310", "py311"]
+        python-version: ["py39", "py310", "py311", "py312"]
     steps:
       - name: Configure AWS Credentials
         uses: aws-actions/configure-aws-credentials@v4
diff --git a/.github/workflows/codebuild-ci.yml b/.github/workflows/codebuild-ci.yml
index 85919f0afe..eef53ff06c 100644
--- a/.github/workflows/codebuild-ci.yml
+++ b/.github/workflows/codebuild-ci.yml
@@ -55,7 +55,7 @@ jobs:
       - name: Run Codestyle & Doc Tests
         uses: aws-actions/aws-codebuild-run-build@v1
         with:
-          project-name: sagemaker-python-sdk-ci-codestyle-doc-tests
+          project-name: ${{ github.event.repository.name }}-ci-codestyle-doc-tests
           source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
   unit-tests:
     runs-on: ubuntu-latest
@@ -63,7 +63,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["py38","py39","py310","py311"]
+        python-version: ["py39","py310","py311","py312"]
     steps:
       - name: Configure AWS Credentials
         uses: aws-actions/configure-aws-credentials@v4
@@ -74,7 +74,7 @@ jobs:
       - name: Run Unit Tests
         uses: aws-actions/aws-codebuild-run-build@v1
         with:
-          project-name: sagemaker-python-sdk-ci-unit-tests
+          project-name: ${{ github.event.repository.name }}-ci-unit-tests
           source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
           env-vars-for-codebuild: |
             PY_VERSION
@@ -93,5 +93,5 @@ jobs:
       - name: Run Integ Tests
         uses: aws-actions/aws-codebuild-run-build@v1
         with:
-          project-name: sagemaker-python-sdk-ci-integ-tests
+          project-name: ${{ github.event.repository.name }}-ci-integ-tests
           source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}'
diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
new file mode 100644
index 0000000000..8fbf42803b
--- /dev/null
+++ b/.github/workflows/codeql.yml
@@ -0,0 +1,35 @@
+name: "CodeQL"
+on:
+  push:
+    branches: [ "master" ]
+  pull_request:
+    branches: [ "master" ]
+  schedule:
+    - cron: '30 15 * * *'
+jobs:
+  analyze:
+    name: Analyze (${{ matrix.language }})
+    runs-on: ${{ 'ubuntu-latest' }}
+    permissions:
+      security-events: write
+      packages: read
+
+    strategy:
+      matrix:
+        include:
+          - language: python
+            build-mode: none
+          - language: java-kotlin
+            build-mode: none
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@6ccd57f4c5d15bdc2fef309bd9fb6cc9db2ef1c6
+      - name: Initialize CodeQL
+        uses: github/codeql-action/init@4b1d7da102ff94aca014c0245062b1a463356d72
+        with:
+          languages: ${{ matrix.language }}
+          build-mode: ${{ matrix.build-mode }}
+      - name: Perform CodeQL Analysis
+        uses: github/codeql-action/analyze@4b1d7da102ff94aca014c0245062b1a463356d72
+        with:
+          category: "/language:${{matrix.language}}"
diff --git a/.github/workflows/security-monitoring.yml b/.github/workflows/security-monitoring.yml
new file mode 100644
index 0000000000..ecce0643e6
--- /dev/null
+++ b/.github/workflows/security-monitoring.yml
@@ -0,0 +1,121 @@
+name: Security Monitoring
+
+on:
+  schedule:
+    - cron: '0 16 * * *'
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.run_id }}
+  cancel-in-progress: true
+
+permissions:
+  id-token: write
+
+jobs:
+  check-code-scanning-alerts:
+    runs-on: ubuntu-latest
+    outputs:
+      code_scanning_alert_status: ${{ steps.check-code-scanning-alerts.outputs.code_scanning_alert_status }}
+    steps:
+      - name: Check for security alerts
+        id: check-code-scanning-alerts
+        uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea
+        with:
+          github-token: ${{ secrets.GH_PAT }}
+          script: |
+            async function checkAlerts() {
+              const owner = '${{ github.repository_owner }}';
+              const repo = '${{ github.event.repository.name }}';
+              const ref = 'refs/heads/master';
+
+              const codeScanningAlerts = await github.rest.codeScanning.listAlertsForRepo({
+                owner,
+                repo,
+                ref: ref
+              });
+              const activeCodeScanningAlerts = codeScanningAlerts.data.filter(alert => alert.state === 'open');
+              core.setOutput('code_scanning_alert_status', activeCodeScanningAlerts.length > 0 ? '1' : '0');
+            }
+            await checkAlerts();
+
+  check-dependabot-alerts:
+    runs-on: ubuntu-latest
+    outputs:
+      dependabot_alert_status: ${{ steps.check-dependabot-alerts.outputs.dependabot_alert_status }}
+    steps:
+      - name: Check for dependabot alerts
+        id: check-dependabot-alerts
+        uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea
+        with:
+          github-token: ${{ secrets.GH_PAT }}
+          script: |
+            async function checkAlerts() {
+              const owner = '${{ github.repository_owner }}';
+              const repo = '${{ github.event.repository.name }}';
+
+              const dependabotAlerts = await github.rest.dependabot.listAlertsForRepo({
+                owner,
+                repo,
+                headers: {
+                  'accept': 'application/vnd.github+json'
+                }
+              });
+              const activeDependabotAlerts = dependabotAlerts.data.filter(alert => alert.state === 'open');
+              core.setOutput('dependabot_alert_status', activeDependabotAlerts.length > 0 ? '1' : '0');
+            }
+            await checkAlerts();
+
+  check-secret-scanning-alerts:
+    runs-on: ubuntu-latest
+    outputs:
+      secret_scanning_alert_status: ${{ steps.check-secret-scanning-alerts.outputs.secret_scanning_alert_status }}
+    steps:
+      - name: Check for secret scanning alerts
+        id: check-secret-scanning-alerts
+        uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea
+        with:
+          github-token: ${{ secrets.GH_PAT }}
+          script: |
+            async function checkAlerts() {
+              const owner = '${{ github.repository_owner }}';
+              const repo = '${{ github.event.repository.name }}';
+
+              const secretScanningAlerts = await github.rest.secretScanning.listAlertsForRepo({
+                owner,
+                repo,
+              });
+              const activeSecretScanningAlerts = secretScanningAlerts.data.filter(alert => alert.state === 'open');
+              core.setOutput('secret_scanning_alert_status', activeSecretScanningAlerts.length > 0 ? '1' : '0');
+            }
+            await checkAlerts();
+
+  put-metric-data:
+    runs-on: ubuntu-latest
+    needs: [check-code-scanning-alerts, check-dependabot-alerts, check-secret-scanning-alerts]
+    steps:
+      - name: Configure AWS Credentials
+        uses: aws-actions/configure-aws-credentials@12e3392609eaaceb7ae6191b3f54bbcb85b5002b
+        with:
+          role-to-assume: ${{ secrets.MONITORING_ROLE_ARN }}
+          aws-region: us-west-2
+      - name: Put Code Scanning Alert Metric Data
+        run: |
+          if [ "${{ needs.check-code-scanning-alerts.outputs.code_scanning_alert_status }}" == "1" ]; then
+            aws cloudwatch put-metric-data --metric-name CodeScanningAlert --namespace SecurityMonitoringMetrics --value 1 --unit Count --dimensions ProjectName=sagemaker-python-sdk
+          else
+            aws cloudwatch put-metric-data --metric-name CodeScanningAlert --namespace SecurityMonitoringMetrics --value 0 --unit Count --dimensions ProjectName=sagemaker-python-sdk
+          fi
+      - name: Put Dependabot Alert Metric Data
+        run: |
+          if [ "${{ needs.check-dependabot-alerts.outputs.dependabot_alert_status }}" == "1" ]; then
+            aws cloudwatch put-metric-data --metric-name DependabotAlert --namespace SecurityMonitoringMetrics --value 1 --unit Count --dimensions ProjectName=sagemaker-python-sdk
+          else
+            aws cloudwatch put-metric-data --metric-name DependabotAlert --namespace SecurityMonitoringMetrics --value 0 --unit Count --dimensions ProjectName=sagemaker-python-sdk
+          fi
+      - name: Put Secret Scanning Alert Metric Data
+        run: |
+          if [ "${{ needs.check-secret-scanning-alerts.outputs.secret_scanning_alert_status }}" == "1" ]; then
+            aws cloudwatch put-metric-data --metric-name SecretScanningAlert --namespace SecurityMonitoringMetrics --value 1 --unit Count --dimensions ProjectName=sagemaker-python-sdk
+          else
+            aws cloudwatch put-metric-data --metric-name SecretScanningAlert --namespace SecurityMonitoringMetrics --value 0 --unit Count --dimensions ProjectName=sagemaker-python-sdk
+          fi
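These metrics only become actionable once something watches them. A minimal sketch of an alarm on one of the three metrics, using boto3 (the alarm name and SNS topic ARN are hypothetical; the metric name, namespace, and dimension come from the workflow above):

.. code:: python

    import boto3

    cloudwatch = boto3.client("cloudwatch", region_name="us-west-2")

    # Alarm whenever the daily datapoint reports an open code scanning alert.
    cloudwatch.put_metric_alarm(
        AlarmName="sagemaker-python-sdk-code-scanning-alerts",  # hypothetical
        Namespace="SecurityMonitoringMetrics",
        MetricName="CodeScanningAlert",
        Dimensions=[{"Name": "ProjectName", "Value": "sagemaker-python-sdk"}],
        Statistic="Maximum",
        Period=86400,  # the workflow publishes once a day
        EvaluationPeriods=1,
        Threshold=0,
        ComparisonOperator="GreaterThanThreshold",
        AlarmActions=["arn:aws:sns:us-west-2:123456789012:security-alerts"],  # hypothetical topic
    )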
diff --git a/.gitignore b/.gitignore
index ad6e488dbd..3d90b52e01 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,6 +32,10 @@ env/
 .python-version
 *.html
 **/_repack_script_launcher.sh
+src/sagemaker/modules/train/container_drivers/sm_train.sh
+src/sagemaker/modules/train/container_drivers/sourcecode.json
+src/sagemaker/modules/train/container_drivers/distributed.json
 tests/data/**/_repack_model.py
 tests/data/experiment/sagemaker-dev-1.0.tar.gz
-src/sagemaker/serve/tmp_workspace
\ No newline at end of file
+src/sagemaker/serve/tmp_workspace
+test-examples
\ No newline at end of file
diff --git a/.pydocstylerc b/.pydocstylerc
index a5083c0d63..9ed879a760 100644
--- a/.pydocstylerc
+++ b/.pydocstylerc
@@ -2,3 +2,4 @@
 inherit = false
 ignore = D104,D107,D202,D203,D213,D214,D400,D401,D404,D406,D407,D411,D413,D414,D415,D417
 match = (?!record_pb2).*\.py
+match-dir = (?!.*test).*
\ No newline at end of file
diff --git a/.pylintrc b/.pylintrc
index 11e2ababa9..223580f4d3 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -94,7 +94,24 @@ disable=
     useless-object-inheritance, # TODO: Enable this check and fix code once Python 2 is no longer supported.
     super-with-arguments,
     raise-missing-from,
-    E1136,
+    C0116, # Missing function or method docstring
+    C0209, # Use f-string instead of format
+    E0015, # Unrecognized option found in config
+    E0702, # Raising a string instead of an exception
+    E1101, # Module has no member (likely dynamic attr)
+    E1136, # Value assigned to something inferred as None
+    R0022, # Useless option value in config
+    R1710, # Inconsistent return statements
+    R1714, # Consider using `in` with comparisons
+    R1729, # Use a generator
+    R1732, # Consider using `with` for resource-allocating operations
+    R1735, # Consider using a dict or list literal
+    W0237, # Argument renamed in override
+    W0613, # Unused argument
+    W0621, # Redefining name from outer scope
+    W0719, # Broad exception raised
+    W1404, # Implicit string concatenation
+    W1514, # `open()` used without encoding

 [REPORTS]

 # Set the output format. Available formats are text, parseable, colorized, msvs
@@ -310,7 +327,7 @@ ignore-mixin-members=yes
 # (useful for modules/projects where namespaces are manipulated during runtime
 # and thus existing member attributes cannot be deduced by static analysis. It
 # supports qualified module names, as well as Unix pattern matching.
-ignored-modules=distutils
+ignored-modules=

 # List of class names for which member attributes should not be checked (useful
 # for classes with dynamically set attributes). This supports the use of
@@ -384,7 +401,7 @@ max-returns=6
 max-branches=12

 # Maximum number of statements in function / method body
-max-statements=100
+max-statements=105

 # Maximum number of parents for a class (see R0901).
 max-parents=7
@@ -436,4 +453,4 @@ analyse-fallback-blocks=no

 # Exceptions that will emit a warning when being caught. Defaults to
Defaults to # "Exception" -overgeneral-exceptions=Exception +overgeneral-exceptions=builtins.Exception diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 0a6e3928b5..0dcc70b9c3 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -5,9 +5,9 @@ version: 2 build: - os: ubuntu-20.04 + os: ubuntu-22.04 tools: - python: "3.9" + python: "3.12" python: diff --git a/CHANGELOG.md b/CHANGELOG.md index 880e5df8c5..37c1d155cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,864 @@ # Changelog +## v2.251.0 (2025-08-21) + +### Features + + * support pipeline versioning + +### Bug Fixes and Other Changes + + * GPT OSS Hotfix + * dockerfile stuck on interactive shell + * add sleep for model deployment + +## v2.250.0 (2025-08-08) + +### Features + + * Add support for InstancePlacementConfig in Estimator for training jobs running on ultraserver capacity + +### Bug Fixes and Other Changes + + * Add more constraints to test requirements + +## v2.249.0 (2025-07-31) + +### Features + + * AWS Batch for SageMaker Training jobs + +### Bug Fixes and Other Changes + + * Directly use customer-provided endpoint name for ModelBuilder deployment. + * update image_uri_configs 07-23-2025 07:18:25 PST + +## v2.248.2 (2025-07-22) + +### Bug Fixes and Other Changes + + * Relax boto3 version requirement + * update image_uri_configs 07-22-2025 07:18:25 PST + * update image_uri_configs 07-18-2025 07:18:28 PST + * add hard dependency on sagemaker-core pypi lib + * When rootlessDocker is enabled, return a fixed SageMaker IP + +## v2.248.1 (2025-07-16) + +### Bug Fixes and Other Changes + + * Nova training support + +## v2.248.0 (2025-07-15) + +### Features + + * integrate amtviz for visualization of tuning jobs + +### Bug Fixes and Other Changes + + * build(deps): bump requests in /tests/data/serve_resources/mlflow/pytorch + * build(deps): bump protobuf from 4.25.5 to 4.25.8 in /requirements/extras + * build(deps): bump mlflow in /tests/data/serve_resources/mlflow/xgboost + * build(deps): bump torch in /tests/data/modules/script_mode + * sanitize git clone repo input url + * Adding Hyperpod feature to enable hyperpod telemetry + * Adding Hyperpod feature to enable hyperpod telemetry + * Bump SMD version to enable custom workflow deployment. 
+ * Update TF DLC python version to py312
+ * update image_uri_configs 07-04-2025 07:18:27 PST
+ * update image_uri_configs 06-26-2025 07:18:35 PST
+ * relax protobuf to <6.32
+
+## v2.247.1 (2025-06-23)
+
+### Bug Fixes and Other Changes
+
+ * update image_uri_configs 06-19-2025 07:18:34 PST
+
+## v2.247.0 (2025-06-13)
+
+### Features
+
+ * Add support for MetricDefinitions in ModelTrainer
+
+### Bug Fixes and Other Changes
+
+ * update jumpstart region_config, update image_uri_configs 06-12-2025 07:18:12 PST
+ * Add ignore_patterns in ModelTrainer to ignore specific files/folders
+ * Allow import failure for internal _hashlib module
+
+## v2.246.0 (2025-06-04)
+
+### Features
+
+ * Triton v25.04 DLC
+
+### Bug Fixes and Other Changes
+
+ * Update Attrs version to widen support
+ * update estimator documentation regarding hyperparameters for source_dir
+
+## v2.245.0 (2025-05-28)
+
+### Features
+
+ * Correct mypy type checking through PEP 561
+
+### Bug Fixes and Other Changes
+
+ * MLFLow update for dependabot
+ * addWaiterTimeoutHandling
+ * merge method inputs with class inputs
+ * update image_uri_configs 05-20-2025 07:18:17 PST
+
+## v2.244.2 (2025-05-19)
+
+### Bug Fixes and Other Changes
+
+ * include model channel for gated uncompressed models
+ * clarify model monitor one time schedule bug
+ * update jumpstart region_config 05-15-2025 07:18:15 PST
+ * update image_uri_configs 05-14-2025 07:18:16 PST
+ * Add image configs and region config for TPE (ap-east-2)
+ * Improve defaults handling in ModelTrainer
+
+## v2.244.1 (2025-05-15)
+
+### Bug Fixes and Other Changes
+
+ * Fix Flask-Limiter version
+ * Fix test_huggingface_tei_uris()
+ * huggingface-llm-neuronx dlc
+ * huggingface-neuronx dlc image_uri
+ * huggingface-tei dlc image_uri
+ * Fix test_deploy_with_update_endpoint()
+ * add AG v1.3
+ * parameter mismatch in update_endpoint
+ * remove --strip-component for untar source tar.gz
+ * Fix type annotations
+ * chore: Allow omegaconf >=2.2,<3
+ * honor json serialization of HPs
+ * Map llama models to correct script
+ * pin test dependency
+ * fix bad initialization script error message
+ * Improve error logging and documentation for issue 4007
+ * build(deps): bump scikit-learn
+ * build(deps): bump mlflow
+ * build(deps): bump mlflow in /tests/data/serve_resources/mlflow/pytorch
+ * chore: Add tei 1.6.0 image
+
+## v2.244.0 (2025-05-02)
+
+### Features
+
+ * support custom workflow deployment in ModelBuilder using SMD image.
+
+### Bug Fixes and Other Changes
+
+ * Add Owner ID check for bucket with path when prefix is provided
+ * Add model server timeout
+ * pin mamba version to 24.11.3-2 to avoid inconsistent test runs
+ * Update ModelTrainer to support s3 uri and tar.gz file as source_dir
+ * chore: add huggingface images
+
+## v2.243.3 (2025-04-23)
+
+### Bug Fixes and Other Changes
+
+ * update readme to reflect py312 upgrade
+ * Revert the PR changes 5122
+ * Py312 upgrade step 2: Update dependencies, integ tests and unit tests
+ * update pr test to deprecate py38 and add py312
+ * update image_uri_configs 04-16-2025 07:18:18 PST
+ * update image_uri_configs 04-15-2025 07:18:10 PST
+ * update image_uri_configs 04-11-2025 07:18:19 PST
+
+## v2.243.2 (2025-04-16)
+
+### Bug Fixes and Other Changes
+
+ * tgi image uri unit tests
+ * Fix deepdiff dependencies
+
+## v2.243.1 (2025-04-11)
+
+### Bug Fixes and Other Changes
+
+ * Added handler for pipeline variable while creating process job
+ * Fix issue #4856 by copying environment variables
+ * remove historical job_name caching which causes long job name
+ * Update instance gpu info
+ * Master
+ * Add mlflow tracking arn telemetry
+ * chore: fix semantic versioning for wildcard identifier
+ * flaky test
+
+### Documentation Changes
+
+ * update pipelines step caching examples to include more steps
+ * update ModelStep data dependency info
+
+## v2.243.0 (2025-03-27)
+
+### Features
+
+ * Enabled update_endpoint through model_builder
+
+### Bug Fixes and Other Changes
+
+ * Update for PT 2.5.1, SMP 2.8.0
+ * chore: move jumpstart region definitions to json file
+ * fix flaky clarify model monitor test
+ * fix flaky spark processor integ
+ * use temp file in unit tests
+ * Update transformers version
+ * Aligned disable_output_compression for @remote with Estimator
+ * Update Jinja version
+ * update image_uri_configs 03-26-2025 07:18:16 PST
+ * chore: fix integ tests to use latest version of model
+ * update image_uri_configs 03-25-2025 07:18:13 PST
+ * Skip tests failed due to deprecated instance type
+ * update image_uri_configs 03-21-2025 07:17:55 PST
+ * factor in set instance type when building JumpStart models in ModelBuilder.
+ * ADD Documentation to ReadtheDocs for Upgrading torch versions
+ * add new regions to JUMPSTART_LAUNCHED_REGIONS
+
+## v2.242.0 (2025-03-14)
+
+### Features
+
+ * add integ tests for training JumpStart models in private hub
+
+### Bug Fixes and Other Changes
+
+ * Torch upgrade
+ * Prevent RunContext overlap between test_run tests
+ * remove s3 output location requirement from hub class init
+ * Fixing Pytorch training python version in tests
+ * update image_uri_configs 03-11-2025 07:18:09 PST
+ * resolve infinite loop in _find_config on Windows systems
+ * pipeline definition function doc update
+
+## v2.241.0 (2025-03-06)
+
+### Features
+
+ * Make DistributedConfig Extensible
+ * support training for JumpStart model references as part of Curated Hub Phase 2
+ * Allow ModelTrainer to accept hyperparameters file
+
+### Bug Fixes and Other Changes
+
+ * Skip tests with deprecated instance type
+ * Ensure Model.is_repack() returns a boolean
+ * Fix error when there is no session to call _create_model_request()
+ * Use sagemaker session's s3_resource in download_folder
+ * Added check for the presence of model package group before creating one
+ * Fix key error in _send_metrics()
+
+## v2.240.0 (2025-02-25)
+
+### Features
+
+ * Add support for TGI Neuronx 0.0.27 and HF PT 2.3.0 image in PySDK
+
+### Bug Fixes and Other Changes
+
+ * Remove main function entrypoint in ModelBuilder dependency manager.
+ * forbid extras in Configs
+ * altconfig hubcontent and reenable integ test
+ * Merge branch 'master-rba' into local_merge
+ * py_version doc fixes
+ * Add backward compatibility for RecordSerializer and RecordDeserializer
+ * update image_uri_configs 02-21-2025 06:18:10 PST
+ * update image_uri_configs 02-20-2025 06:18:08 PST
+
+### Documentation Changes
+
+ * Removed a line about python version requirements of training script which can misguide users.
+
+## v2.239.3 (2025-02-19)
+
+### Bug Fixes and Other Changes
+
+ * added ap-southeast-7 and mx-central-1 for Jumpstart
+ * update image_uri_configs 02-19-2025 06:18:15 PST
+
+## v2.239.2 (2025-02-18)
+
+### Bug Fixes and Other Changes
+
+ * Add warning about not supporting torch.nn.SyncBatchNorm
+ * pass in inference_ami_version to model_based endpoint type
+ * Fix hyperparameter strategy docs
+ * Add framework_version to all TensorFlowModel examples
+ * Move RecordSerializer and RecordDeserializer to sagemaker.serializers and sagemaker.deserializers
+
+## v2.239.1 (2025-02-14)
+
+### Bug Fixes and Other Changes
+
+ * keep sagemaker_session from being overridden to None
+ * Fix all type hint and docstrings for callable
+ * Fix the workshop link for Step Functions
+ * Fix Tensorflow doc link
+ * Fix FeatureGroup docstring
+ * Add type hint for ProcessingOutput
+ * Fix sourcedir.tar.gz filenames in docstrings
+ * Fix documentation for local mode
+ * bug in get latest version was getting the max sorted alphabetically
+ * Add cleanup logic to model builder integ tests for endpoints
+ * Fixed pagination failing while listing collections
+ * fix ValueError when updating a data quality monitoring schedule
+ * Add docstring for image_uris.retrieve
+ * Create GitHub action to trigger canaries
+ * update image_uri_configs 02-04-2025 06:18:00 PST
+
+## v2.239.0 (2025-02-01)
+
+### Features
+
+ * Add support for deepseek recipes
+
+### Bug Fixes and Other Changes
+
+ * mpirun protocol - distributed training with @remote decorator
+ * Allow telemetry only in supported regions
+ * Fix ssh host policy
+
+## v2.238.0 (2025-01-29)
+
+### Features
+
+ * use jumpstart deployment config image as default optimization image
+
+### Bug Fixes and Other Changes
+
+ * chore: add new images for HF TGI
+ * update image_uri_configs 01-29-2025 06:18:08 PST
+ * skip TF tests for unsupported versions
+ * Merge branch 'master-rba' into local_merge
+ * Add missing attributes to local resourceconfig
+ * update image_uri_configs 01-27-2025 06:18:13 PST
+ * update image_uri_configs 01-24-2025 06:18:11 PST
+ * add missing schema definition in docs
+ * Omegaconf upgrade
+ * SageMaker @remote function: Added multi-node functionality
+ * remove option
+ * fix typo
+ * fix tests
+ * Add an option for user to remove inputs and container artifacts when using local model trainer
+
+## v2.237.3 (2025-01-09)
+
+### Bug Fixes and Other Changes
+
+ * pin metadata-version to 2.3
+ * model server might have already done a serialization. honor that by not decoding the request again if it is not already bytes or bytestream
+ * Disable jumpstart tests missing clean up logic
+ * Jumpstart ap southeast 5
+ * add autogluon 1.2
+ * updated inference script to cover context
+ * security update -> use sha256 instead of md5 for file hashing
+ * Fix Flake8 Violations
+ * Added parsing string support for situations where custom code might be used (ie. mlflow)
+ * Updating Inference Optimization Validations
+
+## v2.237.2 (2024-12-17)
+
+### Bug Fixes and Other Changes
+
+ * update image_uri_configs 12-13-2024 17:07:12 PST
+ * Cloudpickle upgrade
+
+## v2.237.1 (2024-12-12)
+
+### Bug Fixes and Other Changes
+
+ * chore: remove support for ecr spec fallbacks for jumpstart models
+ * Cloudpickle Revert
+ * Cloudpickle update
+ * Numpy update
+ * Protobuf update
+ * Update to fetch latest Cloudpickle version
+
+## v2.237.0 (2024-12-05)
+
+### Features
+
+ * Support SageMakerTrainingPlan for training jobs
+ * AMI support for BRM
+ * Adding Bedrock Store model support for HubService
+
+### Bug Fixes and Other Changes
+
+ * Fix unit tests
+ * update boto3 and sagemaker-core version
+ * fix gpu_image uri
+ * Hotfix to construct rubik uri correctly
+ * fix codestyles
+ * fix merge artifact
+ * fix merge artifact
+ * fix test_requirements.txt
+ * chore: Merge from main
+
+## v2.236.0 (2024-12-04)
+
+### Features
+
+ * Partner App Auth Provider for SDK support
+ * add pre-processing and post-processing logic to inference_spec
+ * add utility function to capture local snapshot
+ * support script mode with local train.sh
+
+### Bug Fixes and Other Changes
+
+ * Add graphene to doc requirements
+ * Add graphene to the doc requirements
+ * Enable the Recipe tests marked with @pytest.mark.skip(reason="Hyperpod recipe code unavailable"
+ * Add model trainer documentation
+ * Usage docs for training recipes
+ * Neuron URIs update
+ * Update URIs to public for training recipes
+ * Changes for SMP v2.7.0
+ * Change default source directory to current, add option to specify source dir
+ * Remove default values for fields in recipe_overrides and fix recipe path.
+ * Update MANIFEST.in so that wheel builds correctly
+ * fix the file uploading signature verification error
+ * remove example notebooks artifacts
+ * Morpheus tests
+ * Integ tests for local mode model trainer
+ * Update hyperpod recipe uris
+ * Add interface units for ModelTrainer
+ * Model Trainer Bucket improvements
+ * Update ModelTrainer Interface Parameters
+ * add in-process mode definition to docs
+ * Intelligent defaults for Model Trainer
+ * Fix tests and codestyle
+ * add integ test for base_model_builder_deploy and remove print statement
+ * Revert image builder
+ * pin xgboost dlc to 1.7.1 to fix test
+ * Skip JS model mapping with env vars or image URI provided
+ * Use sagemaker core Session
+ * Integration tests for Model Builder Handshake
+ * [Updated] Add telemetry to ModelTrainer, Estimator and ModelBuilder
+ * Update kandinsky in ModelTrainer and allow setting requirements
+ * add modelID support to model builder InProcess model
+ * Add Rich Logging to Model Builder
+ * Notebooks update for Bugbash
+ * Add bugbash bootstrapping
+ * add inference morpheus nbs
+ * Update ModelTrainer Notebooks
+ * Bug fixes
+ * Single container local training
+ * update notebooks
+ * update notebooks
+ * Add recipes examples
+ * Unified Deployment interface in Model Builder
+ * Use exact python path in trainer template
+ * Support building image from Dockerfile
+ * Add Support for Training Recipes
+ * Trainer handshake
+ * Pass hyperparameters as CLI args
+ * Add in_process mode support for DJL and TorchServe servers
+ * Remove ignored files
+ * Simplify Config Class Names and DistributedRunner structures
+ * Fix bug in script mode setup ModelTrainer
+ * Mask Sensitive Env Logs in Container
+ * Add path to set Additional Settings in ModelTrainer
+ * Add Distributed Training Support Model Trainer
+ * Cleanup ModelTrainer code
+ * Latest Container Image
+ * General image builder
+ * Cleanup ModelTrainer
+ * Revert Image Spec
+ * Support intelligent parameters
+ * Add environment variable bootstrapping script
+ * Add example notebook
+ * Add unit tests for ModelTrainer
+ * Image Spec refactoring and updates
+ * Base model trainer
+
+## v2.235.2 (2024-11-22)
+
+## v2.235.1 (2024-11-20)
+
+### Bug Fixes and Other Changes
+
+ * Update sagemaker-core dep
+ * update image_uri_configs 11-20-2024 06:17:41 PST
+
+## v2.235.0 (2024-11-19)
+
+### Features
+
+ * Optimize() validations across TRT, VLLM, Neuron container optimizations
+
+### Bug Fixes and Other Changes
+
+ * update image_uri_configs 11-19-2024 06:17:58 PST
+
+## v2.234.0 (2024-11-19)
+
+### Features
+
+ * optimization technique related validations.
+
+### Bug Fixes and Other Changes
+
+ * Revert "change: add TGI 2.4.0 image uri (#4922)"
+ * pin testing deps
+ * add TGI 2.4.0 image uri
+ * add jumpstart ap-southeast-5
+ * Move sagemaker-mlflow to extras
+
+## v2.233.0 (2024-11-04)
+
+### Features
+
+ * triton v24.09
+ * Marketplace model support in HubService
+
+### Bug Fixes and Other Changes
+
+ * Fixing JumpStart Tests
+ * bumping smp version from 2.6.0 to 2.6.1
+ * Updates for DJL 0.30.0 release
+
+## v2.232.3 (2024-10-30)
+
+### Bug Fixes and Other Changes
+
+ * update image_uri_configs 10-29-2024 07:17:56 PST
+ * Skip pytorch tests incompatible with latest version 2.4.0
+ * adding eu-central-2 bucket info to JS constants
+ * update image_uri_configs 10-23-2024 11:26:03 PST
+ * update image_uri_configs 10-17-2024 07:17:55 PST
+ * update image_uri_configs 10-03-2024 07:17:59 PST
+ * update image_uri_configs 09-27-2024 07:18:01 PST
+ * modified pull request template
+ * fixing typo in dependency setup
+ * release: huggingface tgi neuronx 0.0.25 image
+ * Revert "update cloudpickle version to >=2.2.1 in pyproject.toml (#4899)"
+ * update cloudpickle version to >=2.2.1 in pyproject.toml
+ * update cloudpickle version to >=2.2.1
+ * chore(deps): bump pyspark from 3.3.1 to 3.3.2 in /requirements/extras
+ * changes for PT 2.4 currency upgrade
+ * chore: add lmi image config in me-central-1
+ * tests: Implement integration tests covering JumpStart PrivateHub workflows
+ * Use Miniforge to replace MambaForge
+
+## v2.232.2 (2024-10-03)
+
+### Bug Fixes and Other Changes
+
+ * Pass kwargs to HuggingFaceModel.deploy()
+ * improve logging and exception messages
+ * remove deprecated distutils
+ * update image_uri_configs 09-24-2024 07:18:00 PST
+
+## v2.232.1 (2024-09-19)
+
+### Bug Fixes and Other Changes
+
+ * update image_uri_configs 09-17-2024 07:17:54 PST
+ * support latest container version in image_uris and DJLModel for lmi c…
+
+## v2.232.0 (2024-09-12)
+
+### Features
+
+ * add deployment config name in modelbuilder telemetry
+ * add Clarify image URIs for us-isof
+
+### Bug Fixes and Other Changes
+
+ * chore: add flaky test markers & skip region with low P3 instance capacity
+ * update image_uri_configs 09-11-2024 11:54:11 PST
+ * update image_uri_configs 09-10-2024 07:18:01 PST
+ * [change] add us-gov and cn region repo accounts to djl and hugging face image metadata
+ * update image_uri_configs 09-06-2024 07:17:55 PST
+ * add us-gov region repo accounts to djl image metadata
+ * pass name from modelbuilder constructor to created model
+
+## v2.231.0 (2024-08-30)
+
+### Features
+
+ * Add SageMaker Core to the dependency
+
+### Bug Fixes and Other Changes
+
+ * Disable test_mnist_async
+ * SMP v2.5
+ * update image_uri_configs 08-29-2024 07:17:59 PST
+
+## v2.230.0 (2024-08-28)
+
+### Features
+
+ * FastAPI integration for In_Process Mode (2/2)
+
+### Bug Fixes and Other Changes
+
+ * chore: add HF LLM neuronx 0.0.24 image
+ * TF-2.16 test modification and handling
+ * fix test fail
+ * Add troubleshooting links to exceptions
+ * cross account private hub model fine-tuning
+ * chore: cleanup jumpstart factory
+ * disable failing integration tests
+
+## v2.229.0 (2024-08-15)
+
+### Features
+
+ * Support for ModelBuilder In_Process Mode (1/2)
+ * Pulling in dependencies (in_process mode) using conda environment
+ * Add optional CodeArtifact login to FrameworkProcessing job script
+ * implemented security-monitoring to send metrics to CW #1510
+
+### Bug Fixes and Other Changes
+
+ * alt configs model deployment and training issues
+ * fix keras extension in integ test
+ * update image_uri_configs 08-13-2024 07:17:54 PST
+ * trn1 instance family does not support volume size
+ * Update model.py
+ * removed log statement
+ * update image_uri_configs 08-09-2024 07:18:00 PST
+ * Added torchrun compatibility for distributed training across multiple GPUs in a single node (single instance)
+ * BiasConfig type hint
+ * add model monitor image accounts for ap-southeast-5 and eu-central-2
+ * aligned UTC times with PST
+ * ensure hpt jobs inherit tags from config
+ * add JumpStart PDT and OSU regions
+ * chore(deps): bump certifi in /src/sagemaker/serve/utils
+ * Updates for DJL 0.29.0 release
+ * chore(deps): bump apache-airflow from 2.9.2 to 2.9.3 in /requirements/extras
+ * chore(deps): bump torch from 2.0.1 to 2.2.0 in /tests/data/serve_resources/mlflow/pytorch
+ * avoided printing stack trace and escaped input
+ * removing kwargs as this is breaking predictor_cls param for mode…
+
+## v2.228.0 (2024-08-06)
+
+### Features
+
+ * triton v24.05
+
+### Bug Fixes and Other Changes
+
+ * chore: telemetry for deployment configs
+ * censoring sensitive values from being logged
+ * update image_uri_configs 08-05-2024 07:17:38 PST
+ * enable uncompressed model artifacts upload to S3 for SAGEMAKER_ENDPOINT overwrite for TGI, TEI, MMS model servers
+ * ModelReference deployment for Alt Configs models
+ * Add optional typecheck for nullable parameters
+ * Update package metadata
+ * release TEI 1.4.0
+
+## v2.227.0 (2024-07-30)
+
+### Features
+
+ * added code scanning through CodeQL
+
+### Bug Fixes and Other Changes
+
+ * Fixed cpu instance type for the estimator register test
+ * update image_uri_configs 07-29-2024 11:28:28 PST
+ * avoid AccessDenied error for a while on SageMaker Studio with do…
+ * SMP PT 2.3 Fix
+ * chore: pin framework version in serverless inference tests
+ * image uri in TGI 2.2.0 image
+ * explicitly access enum member values to avoid Python version related regression
+ * chore: add huggingface TGI 2.2.0 config
+ * update image_uri_configs 07-22-2024 11:53:54 PST
+ * update image_uri_configs 07-17-2024 07:17:38 PST
+ * update image_uri_configs 07-16-2024 07:17:45 PST
+ * add support for new regions
+
+## v2.226.1 (2024-07-17)
+
+## v2.226.0 (2024-07-12)
+
+### Features
+
+ * Curated hub improvements
+ * InferenceSpec support for MMS and testing
+
+### Bug Fixes and Other Changes
+
+ * ModelBuilder not passing HF_TOKEN to model.
+ * update image_uri_configs 07-10-2024 07:18:04 PST
+
+## v2.225.0 (2024-07-10)
+
+### Features
+
+ * model optimization
+
+### Bug Fixes and Other Changes
+
+ * fix integ test
+ * update uris for v1.1.1
+ * update image_uri_configs 07-04-2024 07:17:24 PST
+
+## v2.224.4 (2024-07-04)
+
+### Bug Fixes and Other Changes
+
+ * allow for inf spec and server override to be passed
+
+## v2.224.3 (2024-07-03)
+
+### Bug Fixes and Other Changes
+
+ * Upgrade local dependencies
+ * Improve docstrings for estimator tags
+
+## v2.224.2 (2024-06-27)
+
+### Bug Fixes and Other Changes
+
+ * Update DJLModel class for latest container releases
+ * list_models() for python3.8
+
+## v2.224.1 (2024-06-21)
+
+### Bug Fixes and Other Changes
+
+ * JumpStart CuratedHub Launch
+ * Update README.rst to show conda-forge version of SageMaker SDK
+ * Update tox.ini
+ * chore(deps): bump apache-airflow from 2.9.1 to 2.9.2 in /requirements/extras
+ * Model server override logic
+
+## v2.224.0 (2024-06-19)
+
+### Features
+
+ * JumpStartModel attach
+
+### Bug Fixes and Other Changes
+
+ * feat(sagemaker-mlflow): New features for SageMaker MLflow
+ * Upgrading to PT 2.3 for release
+ * chore: use ml.g5.2xlarge for integ test
+ * Enable telemetry logging for Remote function
+ * Fix Dependabot Issues - MLFlow Version
+
+## v2.223.0 (2024-06-13)
+
+### Features
+
+ * add 'ModelCard' property to Register step
+
+### Bug Fixes and Other Changes
+
+ * Fix Sniping bug fix
+ * Implement custom telemetry logging in SDK
+ * Fix ci unit-tests
+ * update image_uri_configs 06-12-2024 07:17:03 PST
+
+## v2.222.1 (2024-06-12)
+
+### Bug Fixes and Other Changes
+
+ * First changes
+ * estimator.deploy not respecting instance type
+
+## v2.222.0 (2024-06-07)
+
+### Features
+
+ * jumpstart telemetry
+
+### Bug Fixes and Other Changes
+
+ * update image_uri_configs 06-06-2024 07:17:31 PST
+ * bump requests from 2.31.0 to 2.32.2 in /requirements/extras
+ * chore: add HF LLM neuronx 0.0.23 image
+ * Updates for DJL 0.28.0 release
+ * chore(deps): bump mlflow from 2.11.1 to 2.12.1 in /tests/data/serve_resources/mlflow/tensorflow
+ * chore(deps): bump mlflow from 2.11.1 to 2.12.1 in /tests/data/serve_resources/mlflow/xgboost
+ * chore(deps): bump mlflow from 2.10.2 to 2.12.1 in /tests/data/serve_resources/mlflow/pytorch
+ * chore(deps): bump apache-airflow from 2.9.0 to 2.9.1 in /requirements/extras
+ * chore(deps): bump requests from 2.31.0 to 2.32.2 in /tests/data/serve_resources/mlflow/pytorch
+ * Fix ci unit-tests
+ * Making project name in workflow files dynamic
+ * update image_uri_configs 05-29-2024 07:17:35 PST
+ * Update: SM Endpoint Routing Strategy Support.
+
+## v2.221.1 (2024-05-22)
+
+### Bug Fixes and Other Changes
+
+ * Convert pytorchddp distribution to smdistributed distribution
+ * Add tei cpu image
+
+## v2.221.0 (2024-05-20)
+
+### Features
+
+ * onboard tei image config to pysdk
+
+### Bug Fixes and Other Changes
+
+ * JS Model with non-TGI/non-DJL deployment failure
+ * cover tei with image_uris.retrieve API
+ * Add more debugging
+ * model builder limited container support for endpoint mode.
+ * Image URI should take precedence for HF models
+
+## v2.220.0 (2024-05-15)
+
+### Features
+
+ * AutoGluon 1.1.0 image_uris update
+ * add new images for HF TGI release
+ * Add telemetry support for mlflow models
+
+### Bug Fixes and Other Changes
+
+ * add debug logs to workflow container dist creation
+ * model builder race condition on sagemaker session
+ * Add tensorflow_serving support for mlflow models and enable lineage tracking for mlflow models
+ * update image_uri_configs 05-09-2024 07:17:41 PST
+ * skip flakey tests pending investigation
+
+## v2.219.0 (2024-05-08)
+
+### Features
+
+ * allow choosing js payload by alias in private method
+
+### Bug Fixes and Other Changes
+
+ * chore(deps): bump jinja2 from 3.1.3 to 3.1.4 in /requirements/extras
+ * chore(deps): bump tqdm from 4.66.2 to 4.66.3 in /tests/data/serve_resources/mlflow/pytorch
+ * chore(deps): bump jinja2 from 3.1.3 to 3.1.4 in /doc
+ * Updates for SMP v2.3.1
+
+## v2.218.1 (2024-05-03)
+
+### Bug Fixes and Other Changes
+
+ * Fix UserAgent logging in Python SDK
+ * chore: release tgi 2.0.1
+ * chore: update skipped flaky tests
+
+## v2.218.0 (2024-05-01)
+
+### Features
+
+ * set default allow_pickle param to False
+
+### Bug Fixes and Other Changes
+
+ * properly close files in lineage queries and tests
+
 ## v2.217.0 (2024-04-24)

 ### Features
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 24226af4ee..65b7c0ee0c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -61,6 +61,10 @@ Before sending us a pull request, please ensure that:
 1. Follow the instructions at [Modifying an EBS Volume Using Elastic Volumes (Console)](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/requesting-ebs-volume-modifications.html#modify-ebs-volume) to increase the EBS volume size associated with the newly created EC2 instance.
 1. Wait 5-10min for the new EBS volume increase to finalize.
 1. Allow EC2 to claim the additional space by stopping and then starting your EC2 host.
+2. Set up a venv to manage dependencies:
+   1. `python -m venv ~/.venv/myproject-env` to create the venv
+   2. `source ~/.venv/myproject-env/bin/activate` to activate the venv
+   3. `deactivate` to exit the venv

 ### Pull Down the Code
@@ -74,8 +78,8 @@ Before sending us a pull request, please ensure that:
 ### Run the Unit Tests

 1. Install tox using `pip install tox`
-1. Install coverage using `pip install .[test]`
-1. cd into the sagemaker-python-sdk folder: `cd sagemaker-python-sdk` or `cd /environment/sagemaker-python-sdk`
+1. cd into the github project sagemaker-python-sdk folder: `cd sagemaker-python-sdk` or `cd /environment/sagemaker-python-sdk`
+1. Install coverage using `pip install '.[test]'`
 1. Run the following tox command and verify that all code checks and unit tests pass: `tox tests/unit`
 1. You can also run a single test with the following command: `tox -e py310 -- -s -vv ::`
 1. You can run coverage via the runcoverage env: `tox -e runcoverage -- tests/unit` or `tox -e py310 -- tests/unit --cov=sagemaker --cov-append --cov-report xml`
diff --git a/MANIFEST.in b/MANIFEST.in
index c5eeeed043..28f1569c35 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,13 +1,16 @@
 recursive-include src/sagemaker *.py

 include src/sagemaker/image_uri_config/*.json
+include src/sagemaker/pytorch/training_recipes.json
 include src/sagemaker/serve/schema/*.json
 include src/sagemaker/serve/requirements.txt
+include src/sagemaker/modules/train/sm_recipes/training_recipes.json
 recursive-include requirements *

 include VERSION
 include LICENSE.txt
 include README.rst
+include hatch_build.py

 prune tests
diff --git a/README.rst b/README.rst
index e59b2da9c5..f115b1f25b 100644
--- a/README.rst
+++ b/README.rst
@@ -10,6 +10,10 @@ SageMaker Python SDK
    :target: https://pypi.python.org/pypi/sagemaker
    :alt: Latest Version

+.. image:: https://img.shields.io/conda/vn/conda-forge/sagemaker-python-sdk.svg
+   :target: https://anaconda.org/conda-forge/sagemaker-python-sdk
+   :alt: Conda-Forge Version
+
 .. image:: https://img.shields.io/pypi/pyversions/sagemaker.svg
    :target: https://pypi.python.org/pypi/sagemaker
    :alt: Supported Python Versions
@@ -90,10 +94,17 @@ Supported Python Versions
 SageMaker Python SDK is tested on:

-- Python 3.8
 - Python 3.9
 - Python 3.10
 - Python 3.11
+- Python 3.12
+
+Telemetry
+~~~~~~~~~~~~~~~
+
+The ``sagemaker`` library has telemetry enabled to help us better understand user needs, diagnose issues, and deliver new features. This telemetry tracks the usage of various SageMaker functions.
+
+If you prefer to opt out of telemetry, you can easily do so by setting the ``TelemetryOptOut`` parameter to ``true`` in the SDK defaults configuration. For detailed instructions, please visit `Configuring and using defaults with the SageMaker Python SDK `__.

 AWS Permissions
 ~~~~~~~~~~~~~~~
@@ -180,9 +191,9 @@ Setup a Python environment, and install the dependencies listed in ``doc/requirements.txt``:

 ::

     # conda
-    conda create -n sagemaker python=3.7
+    conda create -n sagemaker python=3.12
     conda activate sagemaker
-    conda install sphinx=3.1.1 sphinx_rtd_theme=0.5.0
+    conda install sphinx=5.1.1 sphinx_rtd_theme=0.5.0

     # pip
     pip install -r doc/requirements.txt
diff --git a/VERSION b/VERSION
index 70303736d8..a74cccc543 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-2.217.1.dev0
+2.251.1.dev0
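The Telemetry note added to the README points at the SDK defaults configuration. As a sketch of what opting out can look like in code (the nested ``TelemetryOptOut`` keys follow the documented defaults-config schema; passing the dict inline via ``sagemaker_config`` rather than through a config file is just for illustration):

.. code:: python

    from sagemaker.session import Session

    # Opt out of telemetry through the SDK defaults configuration.
    session = Session(
        sagemaker_config={
            "SchemaVersion": "1.0",
            "SageMaker": {"PythonSDK": {"Modules": {"TelemetryOptOut": True}}},
        }
    )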
diff --git a/doc/amazon_sagemaker_model_building_pipeline.rst b/doc/amazon_sagemaker_model_building_pipeline.rst
index e3548f80f2..1645302d52 100644
--- a/doc/amazon_sagemaker_model_building_pipeline.rst
+++ b/doc/amazon_sagemaker_model_building_pipeline.rst
@@ -408,21 +408,39 @@ Example:
         step_args=step_args_register_model,
     )

-CreateModelStep
+ModelStep
 ````````````````
 Referable Property List:

 - `DescribeModel`_
+  OR
+- `DescribeModelPackage`_

 .. _DescribeModel: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeModel.html#API_DescribeModel_ResponseSyntax
+.. _DescribeModelPackage: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeModelPackage.html#API_DescribeModelPackage_ResponseSyntax

 Example:

+For the model creation use case:
+
 .. code-block:: python

-    step_model = CreateModelStep(...)
-    model_data = step_model.PrimaryContainer.ModelDataUrl
+    create_model_step = ModelStep(
+        name="MyModelCreationStep",
+        step_args = model.create(...)
+    )
+    model_data = create_model_step.properties.PrimaryContainer.ModelDataUrl
+
+For the model registration use case:
+
+.. code-block:: python
+
+    register_model_step = ModelStep(
+        name="MyModelRegistrationStep",
+        step_args=model.register(...)
+    )
+    approval_status = register_model_step.properties.ModelApprovalStatus

 LambdaStep
 `````````````
@@ -912,7 +930,7 @@ Caching is supported for the following step types:
 - :class:`sagemaker.workflow.clarify_check_step.ClarifyCheckStep`
 - :class:`sagemaker.workflow.emr_step.EMRStep`

-In order to create pipeline steps and eventually construct a SageMaker pipeline, you provide parameters within a Python script or notebook. The SageMaker Python SDK creates a pipeline definition by translating these parameters into SageMaker job attributes. Some of these attributes, when changed, cause the step to re-run (See `Caching Pipeline Steps `__ for a detailed list). Therefore, if you update a SDK parameter that is used to create such an attribute, the step will rerun. See the following discussion for examples of this in processing and training steps, which are commonly used steps in Pipelines.
+In order to create pipeline steps and eventually construct a SageMaker pipeline, you provide parameters within a Python script or notebook. The SageMaker Python SDK creates a pipeline definition by translating these parameters into SageMaker job attributes. Some of these attributes, when changed, cause the step to re-run (See `Caching Pipeline Steps `__ for a detailed list). Therefore, if you update a SDK parameter that is used to create such an attribute, the step will rerun. See the following discussion for examples of this in commonly used step types in Pipelines.

 The following example creates a processing step:
@@ -1037,6 +1055,218 @@
 - :code:`entry_point`: The entry point file is included in the training job’s `InputDataConfig Channel `__ array. A unique hash is created from the file (and any other dependencies), and then the file is uploaded to S3 with the hash included in the path. When a different entry point file is used, a new hash is created and the S3 path for that `InputDataConfig Channel `__ object changes, initiating a new step run. For examples of what the S3 paths look like, see the **S3 Artifact Folder Structure** section.
 - :code:`inputs`: The inputs are also included in the training job’s `InputDataConfig `__. Local inputs are uploaded to S3. If the S3 path changes, a new training job is initiated. For examples of S3 paths, see the **S3 Artifact Folder Structure** section.

+The following example creates a tuning step:
+
+.. code-block:: python
+
+    from sagemaker.workflow.steps import TuningStep
+    from sagemaker.tuner import HyperparameterTuner, ContinuousParameter
+    from sagemaker.estimator import Estimator
+    from sagemaker.inputs import TrainingInput
+
+    model_path = f"s3://{default_bucket}/{base_job_prefix}/AbaloneTrain"
+
+    xgb_train = Estimator(
+        image_uri=image_uri,
+        instance_type=training_instance_type,
+        instance_count=1,
+        output_path=model_path,
+        base_job_name=f"{base_job_prefix}/abalone-train",
+        sagemaker_session=pipeline_session,
+        role=role,
+    )
+
+    xgb_train.set_hyperparameters(
+        eval_metric="rmse",
+        objective="reg:squarederror",  # Define the objective metric for the training job
+        num_round=50,
+        max_depth=5,
+        eta=0.2,
+        gamma=4,
+        min_child_weight=6,
+        subsample=0.7,
+        silent=0,
+    )
+
+    objective_metric_name = "validation:rmse"
+
+    hyperparameter_ranges = {
+        "alpha": ContinuousParameter(0.01, 10, scaling_type="Logarithmic"),
+        "lambda": ContinuousParameter(0.01, 10, scaling_type="Logarithmic"),
+    }
+
+    tuner = HyperparameterTuner(
+        xgb_train,
+        objective_metric_name,
+        hyperparameter_ranges,
+        max_jobs=3,
+        max_parallel_jobs=3,
+        strategy="Random",
+        objective_type="Minimize",
+    )
+
+    hpo_args = tuner.fit(
+        inputs={
+            "train": TrainingInput(
+                s3_data=step_process.properties.ProcessingOutputConfig.Outputs["train"].S3Output.S3Uri,
+                content_type="text/csv",
+            ),
+            "validation": TrainingInput(
+                s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
+                    "validation"
+                ].S3Output.S3Uri,
+                content_type="text/csv",
+            ),
+        }
+    )
+
+    step_tuning = TuningStep(
+        name="HPTuning",
+        step_args=hpo_args,
+        cache_config=cache_config,
+    )
+
+The following parameters from the example cause additional tuning (or training) step iterations when you change them:
+
+- :code:`image_uri`: The :code:`image_uri` parameter defines the image used for training, and is used directly in the `AlgorithmSpecification `__ attribute of the training job(s) that are created from the tuning job.
+- :code:`hyperparameters`: All of the hyperparameters passed in the :code:`xgb_train.set_hyperparameters()` method are used directly in the `StaticHyperParameters `__ attribute for the tuning job.
+- The following parameters are all included in the `HyperParameterTuningJobConfig `__ and if any one of them changes, a new tuning job is initiated:
+    - :code:`hyperparameter_ranges`
+    - :code:`objective_metric_name`
+    - :code:`max_jobs`
+    - :code:`max_parallel_jobs`
+    - :code:`strategy`
+    - :code:`objective_type`
+- :code:`inputs`: The inputs are included in any training job’s `InputDataConfig `__ that get created from the tuning job. Local inputs are uploaded to S3. If the S3 path changes, a new tuning job is initiated. For examples of S3 paths, see the S3 Artifact Folder Structure section.
+
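The tuning step above passes a ``cache_config`` that the excerpt never defines. For completeness, a minimal sketch of constructing one (the 30-day expiry is an arbitrary choice):

.. code:: python

    from sagemaker.workflow.steps import CacheConfig

    # Reuse a previous step result when its attributes are unchanged,
    # for up to 30 days (ISO 8601 duration).
    cache_config = CacheConfig(enable_caching=True, expire_after="P30D")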
+The following example creates a transform step:
+
+.. code-block:: python
+
+    from sagemaker.transformer import Transformer
+    from sagemaker.inputs import TransformInput
+    from sagemaker.workflow.steps import TransformStep
+    from sagemaker.workflow.parameters import ParameterString
+
+    base_uri = f"s3://{default_bucket}/abalone"
+    batch_data_uri = sagemaker.s3.S3Uploader.upload(
+        local_path=local_path,
+        desired_s3_uri=base_uri,
+    )
+
+    batch_data = ParameterString(
+        name="BatchData",
+        default_value=batch_data_uri,
+    )
+
+    transformer = Transformer(
+        model_name=step_create_model.properties.ModelName,
+        instance_type="ml.m5.xlarge",
+        instance_count=1,
+        output_path=f"s3://{default_bucket}/AbaloneTransform",
+        env={
+            'class': 'Transformer'
+        }
+    )
+
+    step_transform = TransformStep(
+        name="AbaloneTransform",
+        step_args=transformer.transform(
+            data=batch_data,
+            data_type="S3Prefix"
+        )
+    )
+
+The following parameters from the example cause additional batch transform step iterations when you change them:
+
+- :code:`model_name`: The name of the SageMaker model being used for the transform job.
+- :code:`env`: Environment variables to be set for use during the transform job.
+- :code:`batch_data`: The input data will be included in the transform job’s `TransformInput `__ field. If the S3 path changes, a new transform job is initiated.
+
+The following example creates an automl step:
+
+.. code-block:: python
+
+    from sagemaker.automl.automl import AutoML, AutoMLInput
+    from sagemaker.workflow.pipeline_context import PipelineSession
+    from sagemaker.workflow.automl_step import AutoMLStep
+
+    pipeline_session = PipelineSession()
+
+    auto_ml = AutoML(...,
+        role=role,
+        target_attribute_name="my_target_attribute_name",
+        mode="ENSEMBLING",
+        sagemaker_session=pipeline_session)
+
+    input_training = AutoMLInput(
+        inputs="s3://amzn-s3-demo-bucket/my-training-data",
+        target_attribute_name="my_target_attribute_name",
+        channel_type="training",
+    )
+    input_validation = AutoMLInput(
+        inputs="s3://amzn-s3-demo-bucket/my-validation-data",
+        target_attribute_name="my_target_attribute_name",
+        channel_type="validation",
+    )
+
+    step_args = auto_ml.fit(
+        inputs=[input_training, input_validation]
+    )
+
+    step_automl = AutoMLStep(
+        name="AutoMLStep",
+        step_args=step_args,
+    )
+
+    best_model = step_automl.get_best_auto_ml_model(role=role)
+
+The following parameters from the example cause additional automl step iterations when you change them:
+
+- :code:`target_attribute_name`: The name of the target variable in supervised learning.
+- :code:`mode`: The method that the AutoML job uses to train the model - either AUTO, ENSEMBLING or HYPERPARAMETER_TUNING.
+- :code:`inputs`: The inputs passed to the auto_ml.fit() method are included in the automl job’s `InputDataConfig `__. If the included S3 path(s) change, a new automl job is initiated.
+
+The following example creates an EMR step:
+
+.. code-block:: python
+
+    from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
+
+    emr_config = EMRStepConfig(
+        jar="jar-location",  # required, path to jar file used
+        args=["--verbose", "--force"],  # optional list of arguments to pass to the jar
+        main_class="com.my.Main1",  # optional main class, this can be omitted if jar above has a manifest
+        properties=[  # optional list of Java properties that are set when the step runs
+            {
+                "key": "mapred.tasktracker.map.tasks.maximum",
+                "value": "2"
+            },
+            {
+                "key": "mapreduce.map.sort.spill.percent",
+                "value": "0.90"
+            },
+            {
+                "key": "mapreduce.tasktracker.reduce.tasks.maximum",
+                "value": "5"
+            }
+        ]
+    )
+
+    step_emr = EMRStep(
+        name="EMRSampleStep",  # required
+        cluster_id="j-1ABCDEFG2HIJK",  # include cluster_id to use a running cluster
+        step_config=emr_config,  # required
+        display_name="My EMR Step",
+        description="Pipeline step to execute EMR job"
+    )
+
+The following parameters from the example cause additional EMR step iterations when you change them:
+
+- :code:`cluster_id`: The id of a running cluster to leverage for the EMR job.
+- :code:`emr_config`: Configuration regarding the code that will run on the EMR cluster during the job.
+
+:class:`Note`: A :code:`cluster_config` parameter may also be passed into :code:`EMRStep` in order to spin up a new cluster. This parameter will also trigger additional step iterations if changed.
+

 S3 Artifact Folder Structure
 ----------------------------
diff --git a/doc/api/inference/model_builder.rst b/doc/api/inference/model_builder.rst
index 3099441850..3cfbcbc2c7 100644
--- a/doc/api/inference/model_builder.rst
+++ b/doc/api/inference/model_builder.rst
@@ -3,14 +3,14 @@ Model Builder

 This module contains classes related to Amazon Sagemaker Model Builder

-.. autoclass:: sagemaker.serve.builder.model_builder.ModelBuilder
+.. autoclass:: sagemaker.serve.ModelBuilder

-.. automethod:: sagemaker.serve.builder.model_builder.ModelBuilder.build
+.. automethod:: sagemaker.serve.ModelBuilder.build

-.. automethod:: sagemaker.serve.builder.model_builder.ModelBuilder.save
+.. automethod:: sagemaker.serve.ModelBuilder.save

-.. autoclass:: sagemaker.serve.spec.inference_spec.InferenceSpec
+.. autoclass:: sagemaker.serve.InferenceSpec

-.. autoclass:: sagemaker.serve.builder.schema_builder.SchemaBuilder
+.. autoclass:: sagemaker.serve.SchemaBuilder

-.. autoclass:: sagemaker.serve.marshalling.custom_payload_translator.CustomPayloadTranslator
+.. autoclass:: sagemaker.serve.CustomPayloadTranslator
diff --git a/doc/api/training/index.rst b/doc/api/training/index.rst
index 5f85359d20..285d9f266d 100644
--- a/doc/api/training/index.rst
+++ b/doc/api/training/index.rst
@@ -3,8 +3,9 @@ Training APIs
 #############

 .. toctree::
-    :maxdepth: 4
+    :maxdepth: 1

+    model_trainer
     algorithm
     analytics
     automl
diff --git a/doc/api/training/model_trainer.rst b/doc/api/training/model_trainer.rst
new file mode 100644
index 0000000000..5b0781f810
--- /dev/null
+++ b/doc/api/training/model_trainer.rst
@@ -0,0 +1,17 @@
+ModelTrainer
+------------
+
+.. autoclass:: sagemaker.modules.train.model_trainer.ModelTrainer
+   :members:
+
+Configs
+~~~~~~~
+
+.. automodule:: sagemaker.modules.configs
+   :members:
+
+Distributed
+~~~~~~~~~~~
+
+.. automodule:: sagemaker.modules.distributed
+   :members:
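To go with the new reference page, a small usage sketch of the class it documents (the image URI, script names, and job name are placeholders; the exact keyword names should be checked against the generated API docs above):

.. code:: python

    from sagemaker.modules.train import ModelTrainer
    from sagemaker.modules.configs import SourceCode

    # Point the trainer at a training image and a local script directory.
    source_code = SourceCode(source_dir="./src", entry_script="train.py")  # placeholders

    model_trainer = ModelTrainer(
        training_image="<training image uri>",  # placeholder
        source_code=source_code,
        base_job_name="my-model-trainer-job",   # placeholder
    )

    model_trainer.train()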
diff --git a/doc/conf.py b/doc/conf.py
index 94a5c4d9c6..6c88ddd0e7 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -83,16 +83,11 @@

 html_css_files = [
     "https://cdn.datatables.net/1.10.23/css/jquery.dataTables.min.css",
+    "theme_overrides.css",
+    "pagination.css",
+    "search_accessories.css",
 ]

-html_context = {
-    "css_files": [
-        "_static/theme_overrides.css",
-        "_static/pagination.css",
-        "_static/search_accessories.css",
-    ]
-}
-
 # Example configuration for intersphinx: refer to the Python standard library.
 intersphinx_mapping = {"python": ("http://docs.python.org/", None)}
diff --git a/doc/frameworks/djl/sagemaker.djl_inference.rst b/doc/frameworks/djl/sagemaker.djl_inference.rst
index fd34ae1a23..5b4d138776 100644
--- a/doc/frameworks/djl/sagemaker.djl_inference.rst
+++ b/doc/frameworks/djl/sagemaker.djl_inference.rst
@@ -5,31 +5,7 @@ DJL Classes
 DJLModel
 ---------------------------

-.. autoclass:: sagemaker.djl_inference.model.DJLModel
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
-DeepSpeedModel
----------------------------
-
-.. autoclass:: sagemaker.djl_inference.model.DeepSpeedModel
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
-HuggingFaceAccelerateModel
----------------------------
-
-.. autoclass:: sagemaker.djl_inference.model.HuggingFaceAccelerateModel
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
-FasterTransformerModel
----------------------------
-
-.. autoclass:: sagemaker.djl_inference.model.FasterTransformerModel
+.. autoclass:: sagemaker.djl_inference.DJLModel
    :members:
    :undoc-members:
    :show-inheritance:
@@ -37,7 +13,7 @@
 DJLPredictor
 ---------------------------

-.. autoclass:: sagemaker.djl_inference.model.DJLPredictor
+.. autoclass:: sagemaker.djl_inference.DJLPredictor
    :members:
    :undoc-members:
    :show-inheritance:
diff --git a/doc/frameworks/djl/using_djl.rst b/doc/frameworks/djl/using_djl.rst
index 217f5ed7dd..63b8acd684 100644
--- a/doc/frameworks/djl/using_djl.rst
+++ b/doc/frameworks/djl/using_djl.rst
@@ -2,14 +2,11 @@
 Use DJL with the SageMaker Python SDK
 #######################################

-With the SageMaker Python SDK, you can use Deep Java Library to host models on Amazon SageMaker.
-
 `Deep Java Library (DJL) Serving `_ is a high performance universal stand-alone model serving solution powered by `DJL `_.
 DJL Serving supports loading models trained with a variety of different frameworks. With the SageMaker Python SDK you can
-use DJL Serving to host large models using backends like DeepSpeed and HuggingFace Accelerate.
+use DJL Serving to host large language models for text-generation and text-embedding use-cases.

-For information about supported versions of DJL Serving, see the `AWS documentation `_.
-We recommend that you use the latest supported version because that's where we focus our development efforts.
+You can learn more about Large Model Inference using DJLServing on the `docs site `_.

 For general information about using the SageMaker Python SDK, see :ref:`overview:Using the SageMaker Python SDK`.

@@ -19,238 +16,57 @@ For general information about using the SageMaker Python SDK, see :ref:`overview
 Deploy DJL models
 *******************

-With the SageMaker Python SDK, you can use DJL Serving to host models that have been saved in the HuggingFace pretrained format.
+With the SageMaker Python SDK, you can use DJL Serving to host text-generation and text-embedding models that have been saved in the HuggingFace pretrained format.
These can either be models you have trained/fine-tuned yourself, or models available publicly from the HuggingFace Hub. -DJL Serving in the SageMaker Python SDK supports hosting models for the popular HuggingFace NLP tasks, as well as Stable Diffusion. - -You can either deploy your model using DeepSpeed, FasterTransformer, or HuggingFace Accelerate, or let DJL Serving determine the best backend based on your model architecture and configuration. .. code:: python - # Create a DJL Model, backend is chosen automatically + # DJLModel will infer which container to use, and apply some starter configuration djl_model = DJLModel( - "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id - "my_sagemaker_role", - dtype="fp16", + model_id="", + role="my_sagemaker_role", task="text-generation", - number_of_partitions=2 # number of gpus to partition the model across ) # Deploy the model to an Amazon SageMaker Endpoint and get a Predictor predictor = djl_model.deploy("ml.g5.12xlarge", initial_instance_count=1) -If you want to use a specific backend, then you can create an instance of the corresponding model directly. +Alternatively, you can provide full specifications to the DJLModel to have full control over the model configuration: .. code:: python - # Create a model using the DeepSpeed backend - deepspeed_model = DeepSpeedModel( - "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id - "my_sagemaker_role", - dtype="bf16", - task="text-generation", - tensor_parallel_degree=2, # number of gpus to partition the model across using tensor parallelism - ) - - # Create a model using the HuggingFace Accelerate backend - - hf_accelerate_model = HuggingFaceAccelerateModel( - "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id - "my_sagemaker_role", - dtype="fp16", - task="text-generation", - number_of_partitions=2, # number of gpus to partition the model across - ) - - # Create a model using the FasterTransformer backend - - fastertransformer_model = FasterTransformerModel( - "s3://my_bucket/my_saved_model_artifacts/", # This can also be a HuggingFace Hub model id - "my_sagemaker_role", - data_type="fp16", + djl_model = DJLModel( + model_id="", + role="my_sagemaker_role", task="text-generation", - tensor_parallel_degree=2, # number of gpus to partition the model across + engine="Python", + env={ + "OPTION_ROLLING_BATCH": "lmi-dist", + "TENSOR_PARALLEL_DEGREE": "2", + "OPTION_DTYPE": "bf16", + "OPTION_MAX_ROLLING_BATCH_SIZE": "64", + }, + image_uri=, ) - # Deploy the model to an Amazon SageMaker Endpoint and get a Predictor - deepspeed_predictor = deepspeed_model.deploy("ml.g5.12xlarge", - initial_instance_count=1) - hf_accelerate_predictor = hf_accelerate_model.deploy("ml.g5.12xlarge", - initial_instance_count=1) - fastertransformer_predictor = fastertransformer_model.deploy("ml.g5.12xlarge", - initial_instance_count=1) - -Regardless of which way you choose to create your model, a ``Predictor`` object is returned. You can use this ``Predictor`` -to do inference on the endpoint hosting your DJLModel. + predictor = djl_model.deploy("ml.g5.12xlarge", + initial_instance_count=1) +Regardless of how you create your model, a ``Predictor`` object is returned. Each ``Predictor`` provides a ``predict`` method, which can do inference with json data, numpy arrays, or Python lists. Inference data are serialized and sent to the DJL Serving model server by an ``InvokeEndpoint`` SageMaker operation. 
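For example, here is a minimal sketch of invoking a deployed text-generation endpoint. The payload field names below are illustrative assumptions; the exact request schema depends on the model and the configured serving backend.

.. code:: python

    # `predictor` is the object returned by djl_model.deploy(...) above.
    # The "inputs"/"parameters" shape shown here is a common text-generation
    # payload, but your container may expect a different schema.
    response = predictor.predict(
        {
            "inputs": "What is Amazon SageMaker?",
            "parameters": {"max_new_tokens": 128, "temperature": 0.7},
        }
    )
    print(response)
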
The ``predict`` method returns the result of inference against your model. By default, the inference data is serialized to a json string, and the inference result is a Python dictionary. -Model Directory Structure -========================= - -There are two components that are needed to deploy DJL Serving Models on Sagemaker. -1. Model Artifacts (required) -2. Inference code and Model Server Properties (optional) - -These are stored and handled separately. Model artifacts should not be stored with the custom inference code and -model server configuration. - -Model Artifacts ---------------- - -DJL Serving supports two ways to load models for inference. -1. A HuggingFace Hub model id. -2. Uncompressed model artifacts stored in a S3 bucket. - -HuggingFace Hub model id -^^^^^^^^^^^^^^^^^^^^^^^^ - -Using a HuggingFace Hub model id is the easiest way to get started with deploying Large Models via DJL Serving on SageMaker. -DJL Serving will use this model id to download the model at runtime via the HuggingFace Transformers ``from_pretrained`` API. -This method makes it easy to deploy models quickly, but for very large models the download time can become unreasonable. - -For example, you can deploy the EleutherAI gpt-j-6B model like this: - -.. code:: - - model = DJLModel( - "EleutherAI/gpt-j-6B", - "my_sagemaker_role", - dtype="fp16", - number_of_partitions=2 - ) - - predictor = model.deploy("ml.g5.12xlarge") - -Uncompressed Model Artifacts stored in a S3 bucket -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -For models that are larger than 20GB (total checkpoint size), we recommend that you store the model in S3. -Download times will be much faster compared to downloading from the HuggingFace Hub at runtime. -DJL Serving Models expect a different model structure than most of the other frameworks in the SageMaker Python SDK. -Specifically, DJLModels do not support loading models stored in tar.gz format. -This is because DJL Serving is optimized for large models, and it implements a fast downloading mechanism for large models that require the artifacts be uncompressed. - -For example, lets say you want to deploy the EleutherAI/gpt-j-6B model available on the HuggingFace Hub. -You can download the model and upload to S3 like this: - -.. code:: - - # Requires Git LFS - git clone https://huggingface.co/EleutherAI/gpt-j-6B - - # Upload to S3 - aws s3 sync gpt-j-6B s3://my_bucket/gpt-j-6B - -You would then pass "s3://my_bucket/gpt-j-6B" as ``model_id`` to the ``DJLModel`` like this: - -.. code:: - - model = DJLModel( - "s3://my_bucket/gpt-j-6B", - "my_sagemaker_role", - dtype="fp16", - number_of_partitions=2 - ) - - predictor = model.deploy("ml.g5.12xlarge") - -For language models we expect that the model weights, model config, and tokenizer config are provided in S3. The model -should be loadable from the HuggingFace Transformers AutoModelFor.from_pretrained API, where task -is the NLP task you want to host the model for. The weights must be stored as PyTorch compatible checkpoints. - -Example: - -.. code:: - - my_bucket/my_model/ - |- config.json - |- added_tokens.json - |- config.json - |- pytorch_model-*-of-*.bin # model weights can be partitioned into multiple checkpoints - |- tokenizer.json - |- tokenizer_config.json - |- vocab.json - -For Stable Diffusion models, the model should be loadable from the HuggingFace Diffusers DiffusionPipeline.from_pretrained API. 
- -Inference code and Model Server Properties ------------------------------------------- - -You can provide custom inference code and model server configuration by specifying the ``source_dir`` and -``entry_point`` arguments of the ``DJLModel``. These are not required. The model server configuration can be generated -based on the arguments passed to the constructor, and we provide default inference handler code for DeepSpeed, -HuggingFaceAccelerate, and Stable Diffusion. You can find these handler implementations in the `DJL Serving Github repository. `_ - -You can find documentation for the model server configurations on the `DJL Serving Docs website `_. - -The code and configuration you want to deploy can either be stored locally or in S3. These files will be bundled into -a tar.gz file that will be uploaded to SageMaker. - -For example: - -.. code:: - - sourcedir/ - |- script.py # Inference handler code - |- serving.properties # Model Server configuration file - |- requirements.txt # Additional Python requirements that will be installed at runtime via PyPi - -In the above example, sourcedir will be bundled and compressed into a tar.gz file and uploaded as part of creating the Inference Endpoint. - -The DJL Serving Model Server -============================ - -The endpoint you create with ``deploy`` runs the DJL Serving model server. -The model server loads the model from S3 and performs inference on the model in response to SageMaker ``InvokeEndpoint`` API calls. - -DJL Serving is highly customizable. You can control aspects of both model loading and model serving. Most of the model server -configuration are exposed through the ``DJLModel`` API. The SageMaker Python SDK will use the values it is passed to -create the proper configuration file used when creating the inference endpoint. You can optionally provide your own -``serving.properties`` file via the ``source_dir`` argument. You can find documentation about serving.properties in the -`DJL Serving Documentation for model specific settings. `_ - -Within the SageMaker Python SDK, DJL Serving is used in Python mode. This allows users to provide their inference script, -and data processing scripts in python. For details on how to write custom inference and data processing code, please -see the `DJL Serving Documentation on Python Mode. `_ - -For more information about DJL Serving, see the `DJL Serving documentation. `_ - -************************** -Ahead of time partitioning -************************** - -To optimize the deployment of large models that do not fit in a single GPU, the model’s tensor weights are partitioned at -runtime and each partition is loaded in individual GPU. But runtime partitioning takes significant amount of time and -memory on model loading. So, DJLModel offers an ahead of time partitioning capability for DeepSpeed and FasterTransformer -engines, which lets you partition your model weights and save them before deployment. HuggingFace does not support -tensor parallelism, so ahead of time partitioning cannot be done for it. In our experiment with GPT-J model, loading -this model with partitioned checkpoints increased the model loading time by 40%. - -`partition` method invokes an Amazon SageMaker Training job to partition the model and upload those partitioned -checkpoints to S3 bucket. You can either provide your desired S3 bucket to upload the partitioned checkpoints or it will be -uploaded to the default SageMaker S3 bucket. Please note that this S3 bucket will be remembered for deployment. 
When you -call `deploy` method after partition, DJLServing downloads the partitioned model checkpoints directly from the uploaded -s3 url, if available. - -.. code:: - - # partitions the model using Amazon Sagemaker Training Job. - djl_model.partition("ml.g5.12xlarge") +************************************** +DJL Serving for Large Model Inference +************************************** - predictor = deepspeed_model.deploy("ml.g5.12xlarge", - initial_instance_count=1) +You can learn more about using DJL Serving for Large Model Inference use-cases on our `documentation site `_. -*********************** -SageMaker DJL Classes -*********************** -For information about the different DJL Serving related classes in the SageMaker Python SDK, see https://sagemaker.readthedocs.io/en/stable/frameworks/djl/sagemaker.djl_inference.html. ******************************** SageMaker DJL Serving Containers diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst index 73e2887440..9bd48ef984 100644 --- a/doc/frameworks/pytorch/using_pytorch.rst +++ b/doc/frameworks/pytorch/using_pytorch.rst @@ -21,18 +21,13 @@ To train a PyTorch model by using the SageMaker Python SDK: .. |create pytorch estimator| replace:: Create a ``sagemaker.pytorch.PyTorch`` Estimator .. _create pytorch estimator: #create-an-estimator -.. |call fit| replace:: Call the estimator's ``fit`` method -.. _call fit: #call-the-fit-method - -1. `Prepare a training script <#prepare-a-pytorch-training-script>`_ +1. `Prepare a training script <#prepare-a-pytorch-training-script>`_ OR `Choose an Amazon SageMaker HyperPod recipe`_ 2. |create pytorch estimator|_ -3. |call fit|_ +3. `Call the estimator's fit method or ModelTrainer's train method`_ Prepare a PyTorch Training Script ================================= -Your PyTorch training script must be a Python 3.6 compatible source file. - Prepare your script in a separate source file than the notebook, terminal session, or source file you're using to submit the script to SageMaker via a ``PyTorch`` Estimator. This will be discussed in further detail below. @@ -175,6 +170,16 @@ see `AWS Deep Learning Containers `__ +Choose an Amazon SageMaker HyperPod recipe +========================================== + +Alternatively, instead of using your own training script, you can choose an +`Amazon SageMaker HyperPod recipe `_ to launch training for a supported model. +If using a recipe, you do not need to provide your own training script. You only need to determine +which recipe you want to run. You can modify a recipe as explained in the next section. + + + Create an Estimator =================== @@ -196,10 +201,121 @@ directories ('train' and 'test'). 'test': 's3://my-data-bucket/path/to/my/test/data'}) +Amazon SageMaker HyperPod recipes +--------------------------------- +Alternatively, if you are using Amazon SageMaker HyperPod recipes, you can follow these instructions: +Prerequisites: you need ``git`` installed on your client to access Amazon SageMaker HyperPod recipes code. -Call the fit Method -=================== +When using a recipe, you must set the ``training_recipe`` arg in place of providing a training script. +This can be a recipe from `here `_ +or a local file or a custom URL.
Please note that you must override the following using +``recipe_overrides``: + +* directory paths for the local container in the recipe as appropriate for the Python SDK +* the output S3 URIs +* the Hugging Face access token +* any other recipe fields you wish to edit + +The code snippet below shows an example. +Please refer to `SageMaker docs `_ +for more details about the expected local paths in the container and the Amazon SageMaker +HyperPod recipes tutorial for more examples. +You can override the fields by either setting ``recipe_overrides`` or +providing a modified ``training_recipe`` through a local file or a custom URL. +When using the recipe, any provided ``entry_point`` will be ignored. + +SageMaker will automatically set up the distribution args. +It will also determine the image to use for your model and device type, +but you can override this with the ``image_uri`` arg. + +You can also override the number of nodes in the recipe with the ``instance_count`` arg to the estimator. +``source_dir`` will default to the current working directory unless specified. +A local copy of the training scripts and recipe will be saved in the ``source_dir``. +You can specify any additional packages you want to install for training in an optional ``requirements.txt`` in the ``source_dir``. + +Note: for Llama 3.2 multi-modal models, you need to upgrade the transformers library by providing a ``requirements.txt`` in the ``source_dir`` with ``transformers==4.45.2``. +Please refer to the Amazon SageMaker HyperPod recipes documentation for more details. + + +Here is an example usage for the recipe ``hf_llama3_8b_seq8k_gpu_p5x16_pretrain``. + + +.. code:: python + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "exp_dir": "", + "explicit_log_dir": "/opt/ml/output/tensorboard", + "checkpoint_dir": "/opt/ml/checkpoints", + }, + "model": { + "data": { + "train_dir": "/opt/ml/input/data/train", + "val_dir": "/opt/ml/input/data/val", + }, + }, + } + pytorch_estimator = PyTorch( + output_path=output_path, + base_job_name="llama-recipe", + role=role, + instance_type="ml.p5.48xlarge", + training_recipe="hf_llama3_8b_seq8k_gpu_p5x16_pretrain", + recipe_overrides=recipe_overrides, + sagemaker_session=sagemaker_session, + tensorboard_output_config=tensorboard_output_config, + ) + pytorch_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data', + 'test': 's3://my-data-bucket/path/to/my/test/data'}) + + # Or alternatively with ModelTrainer + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "exp_dir": "", + "explicit_log_dir": "/opt/ml/output/tensorboard", + "checkpoint_dir": "/opt/ml/checkpoints", + }, + "model": { + "data": { + "train_dir": "/opt/ml/input/data/train", + "val_dir": "/opt/ml/input/data/val", + }, + }, + } + + model_trainer = ModelTrainer.from_recipe( + output_path=output_path, + base_job_name="llama-recipe", + training_recipe="training/llama/hf_llama3_8b_seq8k_gpu_p5x16_pretrain", + recipe_overrides=recipe_overrides, + compute=Compute(instance_type="ml.p5.48xlarge"), + sagemaker_session=sagemaker_session + ).with_tensorboard_output_config( + tensorboard_output_config=tensorboard_output_config + ) + + train_input = Input( + channel_name="train", + data_source="s3://my-data-bucket/path/to/my/training/data" + ) + + test_input = Input( + channel_name="test", + data_source="s3://my-data-bucket/path/to/my/test/data" + ) + + model_trainer.train(input_data_config=[train_input, test_input]) + +
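As noted above, instead of layering ``recipe_overrides`` on a named recipe, you can point ``training_recipe`` at a recipe you have already modified locally. A minimal sketch under that assumption (the local path is hypothetical):

.. code:: python

    # "./my_llama_recipe.yaml" is a hypothetical local copy of a recipe,
    # edited directly instead of being overridden via recipe_overrides.
    pytorch_estimator = PyTorch(
        output_path=output_path,
        base_job_name="llama-recipe",
        role=role,
        instance_type="ml.p5.48xlarge",
        training_recipe="./my_llama_recipe.yaml",
        sagemaker_session=sagemaker_session,
    )
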
+Call the estimator's fit method or ModelTrainer's train method +============================================================== You start your training script by calling ``fit`` on a ``PyTorch`` Estimator. ``fit`` takes both required and optional arguments. @@ -257,6 +373,9 @@ To initialize distributed training in your script, call `torch.distributed.init_process_group `_ with the desired backend and the rank of the current host. +Warning: Some torch features, such as (and likely not limited to) ``torch.nn.SyncBatchNorm``, +are not supported, and their use together with ``init_process_group`` will cause an exception during +distributed training. .. code:: python @@ -929,6 +1048,43 @@ see `For versions 1.1 and lower <#for-versions-1.1-and-lower>`_. Where ``requirements.txt`` is an optional file that specifies dependencies on third-party libraries. +Important Packaging Instructions +-------------------------------- + +When creating your model artifact (``model.tar.gz``), follow these steps to avoid common deployment issues: + +1. Navigate to the directory containing your model files: + + .. code:: bash + + cd my_model + +2. Create the tar archive from within this directory: + + .. code:: bash + + tar czvf ../model.tar.gz * + +**Common Mistakes to Avoid:** + +* Do NOT create the archive from the parent directory using ``tar czvf model.tar.gz my_model/``. + This creates an extra directory level that will cause deployment errors. +* Ensure ``inference.py`` is directly under the ``code/`` directory in your archive. +* Verify your archive structure using: + + .. code:: bash + + tar tvf model.tar.gz + + You should see output similar to: + + :: + + model.pth + code/ + code/inference.py + code/requirements.txt + Create a ``PyTorchModel`` object -------------------------------- @@ -947,6 +1103,15 @@ Now call the :class:`sagemaker.pytorch.model.PyTorchModel` constructor to create Now you can call the ``predict()`` method to get predictions from your deployed model. +Troubleshooting +--------------- + +If you encounter a ``FileNotFoundError`` for ``inference.py``, check: + +1. That your model artifact is packaged correctly following the instructions above +2. The structure of your ``model.tar.gz`` file matches the expected layout +3.
You're creating the archive from within the model directory, not from its parent + *********************************************** Attach an estimator to an existing training job *********************************************** diff --git a/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst b/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst index 1d7344fbbb..a645cd5a62 100644 --- a/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst +++ b/doc/frameworks/tensorflow/deploying_tensorflow_serving.rst @@ -64,7 +64,7 @@ If you already have existing model artifacts in S3, you can skip training and de from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge') @@ -74,7 +74,7 @@ Python-based TensorFlow serving on SageMaker has support for `Elastic Inference from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge', accelerator_type='ml.eia1.medium') diff --git a/doc/frameworks/tensorflow/using_tf.rst b/doc/frameworks/tensorflow/using_tf.rst index 1e51b5f43a..5b888f95be 100644 --- a/doc/frameworks/tensorflow/using_tf.rst +++ b/doc/frameworks/tensorflow/using_tf.rst @@ -246,7 +246,7 @@ Training with parameter servers If you specify parameter_server as the value of the distribution parameter, the container launches a parameter server thread on each instance in the training cluster, and then executes your training code. You can find more information on -TensorFlow distributed training at `TensorFlow docs `__. +TensorFlow distributed training at `TensorFlow docs `__. To enable parameter server training: .. code:: python @@ -468,7 +468,7 @@ If you already have existing model artifacts in S3, you can skip training and de from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge') @@ -478,7 +478,7 @@ Python-based TensorFlow serving on SageMaker has support for `Elastic Inference from sagemaker.tensorflow import TensorFlowModel - model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') + model = TensorFlowModel(model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole', framework_version='x.x.x') predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge', accelerator_type='ml.eia1.medium') @@ -767,7 +767,8 @@ This customized Python code must be named ``inference.py`` and is specified thro model = TensorFlowModel(entry_point='inference.py', model_data='s3://mybucket/model.tar.gz', - role='MySageMakerRole') + role='MySageMakerRole', + framework_version='x.x.x') In the example above, ``inference.py`` is assumed to be a file inside ``model.tar.gz``. If you want to use a local file instead, you must add the ``source_dir`` argument. See the documentation on `TensorFlowModel `_. @@ -923,7 +924,8 @@ processing. 
There are 2 ways to do this: model = TensorFlowModel(entry_point='inference.py', dependencies=['requirements.txt'], model_data='s3://mybucket/model.tar.gz', - role='MySageMakerRole') + role='MySageMakerRole', + framework_version='x.x.x') 2. If you are working in a network-isolation situation or if you don't @@ -941,7 +943,8 @@ processing. There are 2 ways to do this: model = TensorFlowModel(entry_point='inference.py', dependencies=['/path/to/folder/named/lib'], model_data='s3://mybucket/model.tar.gz', - role='MySageMakerRole') + role='MySageMakerRole', + framework_version='x.x.x') For more information, see: https://github.com/aws/sagemaker-tensorflow-serving-container#prepost-processing diff --git a/doc/overview.rst b/doc/overview.rst index 319560b5ff..26601900bd 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -4,6 +4,7 @@ Using the SageMaker Python SDK SageMaker Python SDK provides several high-level abstractions for working with Amazon SageMaker. These are: +- **ModelTrainer**: New interface encapsulating training on SageMaker. - **Estimators**: Encapsulate training on SageMaker. - **Models**: Encapsulate built ML models. - **Predictors**: Provide real-time inference and transformation using Python data-types against a SageMaker endpoint. @@ -24,11 +25,16 @@ Train a Model with the SageMaker Python SDK To train a model by using the SageMaker Python SDK, you: 1. Prepare a training script -2. Create an estimator -3. Call the ``fit`` method of the estimator +2. Create a ModelTrainer or Estimator +3. Call the ``train`` method of the ModelTrainer or the ``fit`` method of the Estimator After you train a model, you can save it, and then serve the model as an endpoint to get real-time inferences or get inferences for an entire dataset by using batch transform. + +Important Note: + +* When using torch to load Models, it is recommended to use version torch>=2.6.0 and torchvision>=0.17.0 + Prepare a Training script ========================= @@ -85,6 +91,46 @@ If you want to use, for example, boolean hyperparameters, you need to specify `` For more on training environment variables, please visit `SageMaker Containers `_. +Using ModelTrainer +================== + +To use the ModelTrainer class, you need to provide a few essential parameters such as the training image URI and the source code configuration. The class allows you to spin up a SageMaker training job with minimal parameters, particularly by specifying the source code and training image. + +For more information about class definitions see `ModelTrainer `_. + +Example: Launching a Training Job with Custom Script + +.. 
code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import SourceCode, InputData + + # Image URI for the training job + pytorch_image = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310" + + # Define the script to be run + source_code = SourceCode( + source_dir="basic-script-mode", + requirements="requirements.txt", + entry_script="custom_script.py", + ) + + # Define the ModelTrainer + model_trainer = ModelTrainer( + training_image=pytorch_image, + source_code=source_code, + base_job_name="script-mode", + ) + + # Pass the input data + input_data = InputData( + channel_name="train", + data_source=training_input_path, # S3 path where training data is stored + ) + + # Start the training job + model_trainer.train(input_data_config=[input_data], wait=False) + Using Estimators ================ @@ -1917,7 +1963,7 @@ Make sure to have a Compose Version compatible with your Docker Engine installat Local mode configuration ======================== -The local mode uses a YAML configuration file located at ``~/.sagemaker/config.yaml`` to define the default values that are automatically passed to the ``config`` attribute of ``LocalSession``. This is an example of the configuration, for the full schema, see `sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA `_. +The local mode uses a YAML configuration file located at ``${user_config_directory}/sagemaker/config.yaml`` to define the default values that are automatically passed to the ``config`` attribute of ``LocalSession``. This is an example of the configuration, for the full schema, see `sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA `_. .. code:: yaml @@ -1925,7 +1971,7 @@ The local mode uses a YAML configuration file located at ``~/.sagemaker/config.y local_code: true # Using everything locally region_name: "us-west-2" # Name of the region container_config: # Additional docker container config - shm_size: "128M + shm_size: "128M" If you want to keep everything local, and not use Amazon S3 either, you can enable "local code" in one of two ways: @@ -2524,6 +2570,9 @@ set default values for. 
For the full schema, see `sagemaker.config.config_schema       KmsKeyId: 'kmskeyid10'     TransformResources:       VolumeKmsKeyId: 'volumekmskeyid4' + Tags: +     - Key: 'tag_key' +       Value: 'tag_value'   CompilationJob:   # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateCompilationJob.html     OutputConfig: diff --git a/doc/requirements.txt b/doc/requirements.txt index a65e0e4050..11098e2bc1 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,7 +1,8 @@ -sphinx==5.1.1 -sphinx-rtd-theme==0.5.0 -docutils==0.15.2 -packaging==20.9 -jinja2==3.1.3 +sphinx==7.2.6 +sphinx-rtd-theme==3.0.0 +docutils>=0.18.1,<0.21 +packaging>=23.0,<25 +jinja2==3.1.6 schema==0.7.5 accelerate>=0.24.1,<=0.27.0 +graphene<4.0 diff --git a/doc/v2.rst b/doc/v2.rst index 0677594b31..bca663af33 100644 --- a/doc/v2.rst +++ b/doc/v2.rst @@ -324,9 +324,9 @@ The following serializer/deserializer classes have been renamed and/or moved: +--------------------------------------------------------+-------------------------------------------------------+ | ``sagemaker.predictor._NPYSerializer`` | ``sagemaker.serializers.NumpySerializer`` | +--------------------------------------------------------+-------------------------------------------------------+ -| ``sagemaker.amazon.common.numpy_to_record_serializer`` | ``sagemaker.amazon.common.RecordSerializer`` | +| ``sagemaker.amazon.common.numpy_to_record_serializer`` | ``sagemaker.serializers.RecordSerializer`` | +--------------------------------------------------------+-------------------------------------------------------+ -| ``sagemaker.amazon.common.record_deserializer`` | ``sagemaker.amazon.common.RecordDeserializer`` | +| ``sagemaker.amazon.common.record_deserializer`` | ``sagemaker.deserializers.RecordDeserializer`` | +--------------------------------------------------------+-------------------------------------------------------+ | ``sagemaker.predictor._JsonDeserializer`` | ``sagemaker.deserializers.JSONDeserializer`` | +--------------------------------------------------------+-------------------------------------------------------+ diff --git a/doc/workflows/step_functions/index.rst b/doc/workflows/step_functions/index.rst index a327d376a0..bfe9582341 100644 --- a/doc/workflows/step_functions/index.rst +++ b/doc/workflows/step_functions/index.rst @@ -11,5 +11,5 @@ without having to provision and integrate the AWS services separately. The AWS Step Functions Python SDK uses the SageMaker Python SDK as a dependency.
To get started with step functions, try the workshop or visit the SDK's website: -* `Workshop on using AWS Step Functions with SageMaker `__ +* `Create and manage Amazon SageMaker AI jobs with Step Functions `__ * `AWS Step Functions Python SDK website `__ diff --git a/hatch_build.py b/hatch_build.py new file mode 100644 index 0000000000..fc75584f17 --- /dev/null +++ b/hatch_build.py @@ -0,0 +1,43 @@ +from __future__ import absolute_import + +import os +import sys + +from hatchling.metadata.plugin.interface import MetadataHookInterface + + +class CustomMetadataHook(MetadataHookInterface): + def update(self, metadata): + metadata["optional-dependencies"] = get_optional_dependencies(self.root) + + +def get_optional_dependencies(root): + + def read_feature_deps(feature): + req_file = os.path.join(root, "requirements", "extras", f"{feature}_requirements.txt") + with open(req_file, encoding="utf-8") as f: + return list(filter(lambda d: not d.startswith("#"), f.read().splitlines())) + + optional_dependencies = {"all": []} + + for feature in ("feature-processor", "huggingface", "local", "scipy", "sagemaker-mlflow"): + dependencies = read_feature_deps(feature) + optional_dependencies[feature] = dependencies + optional_dependencies["all"].extend(dependencies) + + # Test dependencies come last because we don't want them in `all` + optional_dependencies["test"] = read_feature_deps("test") + optional_dependencies["test"].extend(optional_dependencies["all"]) + + # remove torch and torchvision if python version is not 3.10/3.11 + if sys.version_info.minor not in (10, 11): + optional_dependencies["test"] = list( + filter( + lambda d: not d.startswith( + ("sentencepiece", "transformers", "torch", "torchvision") + ), + optional_dependencies["test"], + ) + ) + + return optional_dependencies diff --git a/pyproject.toml b/pyproject.toml index aa4949aa1c..e35a43c163 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,97 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "sagemaker" +dynamic = ["version", "optional-dependencies"] +description = "Open source library for training and deploying models on Amazon SageMaker." 
+readme = "README.rst" +requires-python = ">=3.9" +authors = [ + { name = "Amazon Web Services" }, +] +keywords = [ + "AI", + "AWS", + "Amazon", + "ML", + "MXNet", + "Tensorflow", +] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "attrs>=24,<26", + "boto3>=1.39.5,<2.0", + "cloudpickle>=2.2.1", + "docker", + "fastapi", + "google-pasta", + "importlib-metadata>=1.4.0,<7.0", + "jsonschema", + "numpy==1.26.4", + "omegaconf>=2.2,<3", + "packaging>=23.0,<25", + "pandas", + "pathos", + "platformdirs", + "protobuf>=3.12,<6.32", + "psutil", + "PyYAML>=6.0.1", + "requests", + "sagemaker-core>=1.0.17,<2.0.0", + "schema", + "smdebug_rulesconfig==1.0.1", + "tblib>=1.7.0,<4", + "tqdm", + "urllib3>=1.26.8,<3.0.0", + "uvicorn", + "graphene>=3,<4" +] + +[project.scripts] +sagemaker-upgrade-v2 = "sagemaker.cli.compatibility.v2.sagemaker_upgrade_v2:main" + +[project.urls] +Homepage = "https://github.com/aws/sagemaker-python-sdk" + +[tool.hatch.version] +path = "VERSION" +pattern = "(?P<version>.+)" + +# Dynamically define optional dependencies from requirements.txt files so +# they can be tracked by Dependabot +[tool.hatch.metadata.hooks.custom] + +[tool.hatch.build.targets.wheel] +core-metadata-version = "2.3" +packages = ["src/sagemaker"] +exclude = ["src/sagemaker/serve/model_server/triton/pack_conda_env.sh"] + +[tool.hatch.build.targets.wheel.shared-scripts] +"src/sagemaker/serve/model_server/triton/pack_conda_env.sh" = "pack_conda_env.sh" + +[tool.hatch.build.targets.sdist] +core-metadata-version = "2.3" +only-include = [ + "/requirements/extras", + "/src", + "/VERSION", +] + +[tool.pytest.ini_options] +addopts = ["-vv"] +testpaths = ["tests"] + [tool.black] line-length = 100 diff --git a/requirements/extras/feature-processor_requirements.txt b/requirements/extras/feature-processor_requirements.txt index 0d844a192a..affb6c7bc5 100644 --- a/requirements/extras/feature-processor_requirements.txt +++ b/requirements/extras/feature-processor_requirements.txt @@ -1,2 +1,2 @@ -pyspark==3.3.1 +pyspark==3.3.2 sagemaker-feature-store-pyspark-3.3 diff --git a/requirements/extras/huggingface_requirements.txt b/requirements/extras/huggingface_requirements.txt index c7ec458ea5..3ee6208618 100644 --- a/requirements/extras/huggingface_requirements.txt +++ b/requirements/extras/huggingface_requirements.txt @@ -1,2 +1,5 @@ accelerate>=0.24.1,<=0.27.0 sagemaker_schema_inference_artifacts>=0.0.5 +uvicorn>=0.30.1 +fastapi>=0.111.0 +nest-asyncio diff --git a/requirements/extras/local_requirements.txt b/requirements/extras/local_requirements.txt index c182360128..ea57b82e9a 100644 --- a/requirements/extras/local_requirements.txt +++ b/requirements/extras/local_requirements.txt @@ -1,3 +1,3 @@ urllib3>=1.26.8,<3.0.0 -docker>=5.0.2,<7.0.0 -PyYAML>=5.4.1,<7 +docker>=5.0.2,<8.0.0 +PyYAML>=6.0.1,<7 diff --git a/requirements/extras/sagemaker-mlflow_requirements.txt b/requirements/extras/sagemaker-mlflow_requirements.txt new file mode 100644 index 0000000000..75f330b0e6 --- /dev/null +++ b/requirements/extras/sagemaker-mlflow_requirements.txt @@ -0,0 +1 @@ +sagemaker-mlflow>=0.1.0 diff --git a/requirements/extras/scipy_requirements.txt
b/requirements/extras/scipy_requirements.txt index 0e99587e6e..44ce1d9331 100644 --- a/requirements/extras/scipy_requirements.txt +++ b/requirements/extras/scipy_requirements.txt @@ -1 +1 @@ -scipy==1.10.1 +scipy==1.11.3 diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 43da930636..d66235d84a 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -1,5 +1,7 @@ tox==3.24.5 -flake8==4.0.1 +numpy==1.26.4 +build[virtualenv]==1.2.1 +flake8==7.1.2 pytest==6.2.5 pytest-cov==3.0.0 pytest-rerunfailures==10.2 @@ -12,27 +14,50 @@ awslogs==0.14.0 black==24.3.0 stopit==1.1.2 # Update tox.ini to have correct version of airflow constraints file -apache-airflow==2.9.0 +apache-airflow==2.10.4 apache-airflow-providers-amazon==7.2.1 -attrs>=23.1.0,<24 -fabric==2.6.0 -requests==2.31.0 +Flask-Limiter==3.11 +attrs>=24,<26 +fabric==3.2.2 +requests==2.32.2 sagemaker-experiments==0.1.35 -Jinja2==3.1.3 +Jinja2==3.1.6 pyvis==0.2.1 -pandas>=1.3.5,<1.5 +pandas==1.4.4 scikit-learn==1.3.0 cloudpickle==2.2.1 -PyYAML==6.0 +jsonpickle<4.0.0 +PyYAML>=6.0.1 # TODO find workaround xgboost>=1.6.2,<=1.7.6 pillow>=10.0.1,<=11 -transformers>=4.36.0 +opentelemetry-proto==1.27.0 +opentelemetry_exporter_otlp==1.27.0 +protobuf==4.25.8 +tensorboard>=2.16.2,<=2.18.0 +transformers==4.48.0 sentencepiece==0.1.99 # https://github.com/triton-inference-server/server/issues/6246 tritonclient[http]<2.37.0 -onnx>=1.15.0 +onnx==1.17.0 # tf2onnx==1.15.1 nbformat>=5.9,<6 accelerate>=0.24.1,<=0.27.0 schema==0.7.5 +tensorflow>=2.16.2,<=2.18.0 +mlflow>=2.14.2,<3 +huggingface_hub==0.26.2 +uvicorn>=0.30.1 +fastapi==0.115.4 +nest-asyncio +sagemaker-mlflow>=0.1.0 +deepdiff>=8.0.0 +orderly-set<5.4.0 +lexicon +networkx==3.2.1 +mypy-boto3-appflow==1.35.39 +mypy-boto3-rds==1.35.72 +mypy-boto3-redshift-data==1.35.51 +mypy-boto3-s3==1.35.76 +mypy-extensions==1.0.0 +mypy==1.9.0 diff --git a/requirements/tox/doc8_requirements.txt b/requirements/tox/doc8_requirements.txt index e4a040dd4d..8707c06621 100644 --- a/requirements/tox/doc8_requirements.txt +++ b/requirements/tox/doc8_requirements.txt @@ -1,2 +1,2 @@ -doc8==0.10.1 -Pygments==2.15.0 +doc8==1.1.2 +Pygments==2.18.0 diff --git a/requirements/tox/flake8_requirements.txt b/requirements/tox/flake8_requirements.txt index b3ccfca84f..63a79da444 100644 --- a/requirements/tox/flake8_requirements.txt +++ b/requirements/tox/flake8_requirements.txt @@ -1,2 +1,2 @@ -flake8==4.0.1 -flake8-future-import==0.4.6 +flake8==7.1.2 +flake8-future-import==0.4.7 diff --git a/requirements/tox/pylint_requirements.txt b/requirements/tox/pylint_requirements.txt index b307f21762..0e5db209fe 100644 --- a/requirements/tox/pylint_requirements.txt +++ b/requirements/tox/pylint_requirements.txt @@ -1,2 +1,2 @@ -pylint==2.6.2 -astroid==2.4.2 +pylint==3.0.3 +astroid==3.0.2 diff --git a/requirements/tox/spelling_requirements.txt b/requirements/tox/spelling_requirements.txt index 769415eb2c..94d6bc314e 100644 --- a/requirements/tox/spelling_requirements.txt +++ b/requirements/tox/spelling_requirements.txt @@ -1,2 +1,2 @@ pyenchant==3.2.2 -pylint==2.6.2 +pylint==3.0.3 diff --git a/requirements/tox/twine_requirements.txt b/requirements/tox/twine_requirements.txt index 489eeb83e0..9c0a7cab5e 100644 --- a/requirements/tox/twine_requirements.txt +++ b/requirements/tox/twine_requirements.txt @@ -1 +1,2 @@ +build==1.2.1 twine==5.0.0 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 80eaced105..0000000000 --- a/setup.cfg +++ 
/dev/null @@ -1,12 +0,0 @@ -# Test args for pytest; disable stdout capturing by default. -[tool:pytest] -addopts = - -vv -testpaths = tests - -[aliases] -test=pytest - -[metadata] -description_file = README.rst -license_files = LICENSE.txt diff --git a/setup.py b/setup.py index 9242e69cfd..f651c27898 100644 --- a/setup.py +++ b/setup.py @@ -14,122 +14,55 @@ from __future__ import absolute_import import os -from glob import glob +import re import sys +from ast import literal_eval +from glob import glob +from pathlib import Path from setuptools import find_packages, setup - -def read(fname): +sys.stderr.write( """ - Args: - fname: - """ - return open(os.path.join(os.path.dirname(__file__), fname)).read() - +=============================== +Unsupported installation method +=============================== -def read_version(): - return read("VERSION").strip() +This version of sagemaker no longer supports installation with `python setup.py install`. +Please use `python -m pip install .` instead. +""" +) -def read_requirements(filename): - """Reads requirements file which lists package dependencies. - - Args: - filename: type(str) Relative file path of requirements.txt file +HERE = Path(__file__).parent.absolute() +PYPROJECT = HERE.joinpath("pyproject.toml").read_text(encoding="utf-8") +BUILD_SCRIPT = HERE.joinpath("hatch_build.py").read_text(encoding="utf-8") - Returns: - list of dependencies extracted from file - """ - with open(os.path.abspath(filename)) as fp: - deps = [line.strip() for line in fp.readlines()] - return deps +def get_dependencies(): + pattern = r"^dependencies = (\[.*?\])$" + array = re.search(pattern, PYPROJECT, flags=re.MULTILINE | re.DOTALL).group(1) + return literal_eval(array) -# Declare minimal set for installation -required_packages = [ - "attrs>=23.1.0,<24", - "boto3>=1.33.3,<2.0", - "cloudpickle==2.2.1", - "google-pasta", - "numpy>=1.9.0,<2.0", - "protobuf>=3.12,<5.0", - "smdebug_rulesconfig==1.0.1", - "importlib-metadata>=1.4.0,<7.0", - "packaging>=20.0", - "pandas", - "pathos", - "schema", - "PyYAML~=6.0", - "jsonschema", - "platformdirs", - "tblib>=1.7.0,<4", - "urllib3>=1.26.8,<3.0.0", - "requests", - "docker", - "tqdm", - "psutil", -] -# Specific use case dependencies -# Keep format of *_requirements.txt to be tracked by dependabot -extras = { - "local": read_requirements("requirements/extras/local_requirements.txt"), - "scipy": read_requirements("requirements/extras/scipy_requirements.txt"), - "feature-processor": read_requirements( - "requirements/extras/feature-processor_requirements.txt" - ), - "huggingface": read_requirements("requirements/extras/huggingface_requirements.txt"), -} -# Meta dependency groups -extras["all"] = [item for group in extras.values() for item in group] -# Tests specific dependencies (do not need to be included in 'all') -test_dependencies = read_requirements("requirements/extras/test_requirements.txt") -# test dependencies are a superset of testing and extra dependencies -test_dependencies.extend(extras["all"]) -# remove torch and torchvision if python version is not 3.10/3.11 -if sys.version_info.minor != 10 or sys.version_info.minor != 11: - test_dependencies = [ - module - for module in test_dependencies - if not ( - module.startswith("transformers") - or module.startswith("sentencepiece") - or module.startswith("torch") - or module.startswith("torchvision") - ) - ] +def get_optional_dependencies(): + pattern = r"^def get_optional_dependencies.+" + function = re.search(pattern, BUILD_SCRIPT, flags=re.MULTILINE | re.DOTALL).group(0) 
+ identifiers = {} + exec(function, None, identifiers) + return identifiers["get_optional_dependencies"](str(HERE)) -extras["test"] = (test_dependencies,) setup( name="sagemaker", - version=read_version(), - description="Open source library for training and deploying models on Amazon SageMaker.", + version=HERE.joinpath("VERSION").read_text().strip(), packages=find_packages("src"), package_dir={"": "src"}, - package_data={"": ["*.whl"]}, + package_data={"": ["*.whl", "py.typed"]}, py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob("src/*.py")], include_package_data=True, - long_description=read("README.rst"), - author="Amazon Web Services", - url="https://github.com/aws/sagemaker-python-sdk/", - license="Apache License 2.0", - keywords="ML Amazon AWS AI Tensorflow MXNet", - python_requires=">= 3.8", - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Intended Audience :: Developers", - "Natural Language :: English", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - ], - install_requires=required_packages, - extras_require=extras, + install_requires=get_dependencies(), + extras_require=get_optional_dependencies(), entry_points={ "console_scripts": [ "sagemaker-upgrade-v2=sagemaker.cli.compatibility.v2.sagemaker_upgrade_v2:main", diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index a1769b5a4c..71ea51c60f 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -74,5 +74,6 @@ ) from sagemaker.debugger import ProfilerConfig, Profiler # noqa: F401 +from sagemaker.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401 __version__ = importlib_metadata.version("sagemaker") diff --git a/src/sagemaker/_studio.py b/src/sagemaker/_studio.py index a23fae87e9..22f1c94c5f 100644 --- a/src/sagemaker/_studio.py +++ b/src/sagemaker/_studio.py @@ -65,7 +65,10 @@ def _find_config(working_dir=None): wd = Path(working_dir) if working_dir else Path.cwd() path = None - while path is None and not wd.match("/"): + + # Get the root of the current working directory for both Windows and Unix-like systems + root = Path(wd.anchor) + while path is None and wd != root: candidate = wd / STUDIO_PROJECT_CONFIG if Path.exists(candidate): path = candidate diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 78aa655e04..b48adda44c 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -24,6 +24,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -37,6 +38,8 @@ def retrieve_options( retrieve the supported accept types. (Default: None). model_version (str): The version of the model for which to retrieve the supported accept types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -60,11 +63,12 @@ def retrieve_options( ) return artifacts._retrieve_supported_accept_types( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -73,10 +77,12 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -87,6 +93,8 @@ def retrieve_default( retrieve the default accept type. (Default: None). model_version (str): The version of the model for which to retrieve the default accept type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -98,6 +106,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default accept type to use for the model. @@ -110,11 +119,13 @@ def retrieve_default( ) return artifacts._retrieve_default_accept_type( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/algorithm.py b/src/sagemaker/algorithm.py index a177b93f03..f3fd2c954e 100644 --- a/src/sagemaker/algorithm.py +++ b/src/sagemaker/algorithm.py @@ -68,7 +68,7 @@ def __init__( encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, use_spot_instances: Union[bool, PipelineVariable] = False, max_wait: Optional[Union[int, PipelineVariable]] = None, - **kwargs # pylint: disable=W0613 + **kwargs, # pylint: disable=W0613 ): """Initialize an ``AlgorithmEstimator`` instance. @@ -157,6 +157,20 @@ def __init__( available (default: ``None``). **kwargs: Additional kwargs. This is unused. It's only added for AlgorithmEstimator to ignore the irrelevant arguments. + + Raises: + ValueError: + - If an AWS IAM Role is not provided. + - Bad value for instance type. 
+ RuntimeError: + - When setting up custom VPC, both subnets and security_group_ids are not provided + - If instance_count > 1 (distributed training) with instance type local or local gpu + - If LocalSession is not used with instance type local or local gpu + - file:// output path used outside of local mode + botocore.exceptions.ClientError: + - algorithm arn is incorrect + - insufficient permission to access/ describe algorithm + - algorithm is in a different region """ self.algorithm_arn = algorithm_arn super(AlgorithmEstimator, self).__init__( @@ -271,7 +285,7 @@ def create_model( serializer=IdentitySerializer(), deserializer=BytesDeserializer(), vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT, - **kwargs + **kwargs, ): """Create a model to deploy. @@ -325,7 +339,7 @@ def predict_wrapper(endpoint, session): vpc_config=self.get_vpc_config(vpc_config_override), sagemaker_session=self.sagemaker_session, predictor_cls=predictor_cls, - **kwargs + **kwargs, ) def transformer( diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index bf4e44df25..9f7bc3bda6 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -55,7 +55,7 @@ def __init__( instance_type: Optional[Union[str, PipelineVariable]] = None, data_location: Optional[str] = None, enable_network_isolation: Union[bool, PipelineVariable] = False, - **kwargs + **kwargs, ): """Initialize an AmazonAlgorithmEstimatorBase. @@ -91,7 +91,7 @@ def __init__( instance_count, instance_type, enable_network_isolation=enable_network_isolation, - **kwargs + **kwargs, ) data_location = data_location or ( diff --git a/src/sagemaker/amazon/common.py b/src/sagemaker/amazon/common.py index 4632bda628..fc5d355749 100644 --- a/src/sagemaker/amazon/common.py +++ b/src/sagemaker/amazon/common.py @@ -13,282 +13,13 @@ """Placeholder docstring""" from __future__ import absolute_import -import io -import logging -import struct -import sys - -import numpy as np - -from sagemaker.amazon.record_pb2 import Record -from sagemaker.deprecations import deprecated_class -from sagemaker.deserializers import SimpleBaseDeserializer -from sagemaker.serializers import SimpleBaseSerializer -from sagemaker.utils import DeferredError - - -class RecordSerializer(SimpleBaseSerializer): - """Serialize a NumPy array for an inference request.""" - - def __init__(self, content_type="application/x-recordio-protobuf"): - """Initialize a ``RecordSerializer`` instance. - - Args: - content_type (str): The MIME type to signal to the inference endpoint when sending - request data (default: "application/x-recordio-protobuf"). - """ - super(RecordSerializer, self).__init__(content_type=content_type) - - def serialize(self, data): - """Serialize a NumPy array into a buffer containing RecordIO records. - - Args: - data (numpy.ndarray): The data to serialize. - - Returns: - io.BytesIO: A buffer containing the data serialized as records. - """ - if len(data.shape) == 1: - data = data.reshape(1, data.shape[0]) - - if len(data.shape) != 2: - raise ValueError( - "Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape) - ) - - buffer = io.BytesIO() - write_numpy_to_dense_tensor(buffer, data) - buffer.seek(0) - - return buffer - - -class RecordDeserializer(SimpleBaseDeserializer): - """Deserialize RecordIO Protobuf data from an inference endpoint.""" - - def __init__(self, accept="application/x-recordio-protobuf"): - """Initialize a ``RecordDeserializer`` instance. 
- - Args: - accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that - is expected from the inference endpoint (default: - "application/x-recordio-protobuf"). - """ - super(RecordDeserializer, self).__init__(accept=accept) - - def deserialize(self, data, content_type): - """Deserialize RecordIO Protobuf data from an inference endpoint. - - Args: - data (object): The protobuf message to deserialize. - content_type (str): The MIME type of the data. - Returns: - list: A list of records. - """ - try: - return read_records(data) - finally: - data.close() - - -def _write_feature_tensor(resolved_type, record, vector): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.features["values"].int32_tensor.values.extend(vector) - elif resolved_type == "Float64": - record.features["values"].float64_tensor.values.extend(vector) - elif resolved_type == "Float32": - record.features["values"].float32_tensor.values.extend(vector) - - -def _write_label_tensor(resolved_type, record, scalar): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.label["values"].int32_tensor.values.extend([scalar]) - elif resolved_type == "Float64": - record.label["values"].float64_tensor.values.extend([scalar]) - elif resolved_type == "Float32": - record.label["values"].float32_tensor.values.extend([scalar]) - - -def _write_keys_tensor(resolved_type, record, vector): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.features["values"].int32_tensor.keys.extend(vector) - elif resolved_type == "Float64": - record.features["values"].float64_tensor.keys.extend(vector) - elif resolved_type == "Float32": - record.features["values"].float32_tensor.keys.extend(vector) - - -def _write_shape(resolved_type, record, scalar): - """Placeholder Docstring""" - if resolved_type == "Int32": - record.features["values"].int32_tensor.shape.extend([scalar]) - elif resolved_type == "Float64": - record.features["values"].float64_tensor.shape.extend([scalar]) - elif resolved_type == "Float32": - record.features["values"].float32_tensor.shape.extend([scalar]) - - -def write_numpy_to_dense_tensor(file, array, labels=None): - """Writes a numpy array to a dense tensor - - Args: - file: - array: - labels: - """ - - # Validate shape of array and labels, resolve array and label types - if not len(array.shape) == 2: - raise ValueError("Array must be a Matrix") - if labels is not None: - if not len(labels.shape) == 1: - raise ValueError("Labels must be a Vector") - if labels.shape[0] not in array.shape: - raise ValueError( - "Label shape {} not compatible with array shape {}".format( - labels.shape, array.shape - ) - ) - resolved_label_type = _resolve_type(labels.dtype) - resolved_type = _resolve_type(array.dtype) - - # Write each vector in array into a Record in the file object - record = Record() - for index, vector in enumerate(array): - record.Clear() - _write_feature_tensor(resolved_type, record, vector) - if labels is not None: - _write_label_tensor(resolved_label_type, record, labels[index]) - _write_recordio(file, record.SerializeToString()) - - -def write_spmatrix_to_sparse_tensor(file, array, labels=None): - """Writes a scipy sparse matrix to a sparse tensor - - Args: - file: - array: - labels: - """ - try: - import scipy - except ImportError as e: - logging.warning( - "scipy failed to import. Sparse matrix functions will be impaired or broken." 
- ) - # Any subsequent attempt to use scipy will raise the ImportError - scipy = DeferredError(e) - - if not scipy.sparse.issparse(array): - raise TypeError("Array must be sparse") - - # Validate shape of array and labels, resolve array and label types - if not len(array.shape) == 2: - raise ValueError("Array must be a Matrix") - if labels is not None: - if not len(labels.shape) == 1: - raise ValueError("Labels must be a Vector") - if labels.shape[0] not in array.shape: - raise ValueError( - "Label shape {} not compatible with array shape {}".format( - labels.shape, array.shape - ) - ) - resolved_label_type = _resolve_type(labels.dtype) - resolved_type = _resolve_type(array.dtype) - - csr_array = array.tocsr() - n_rows, n_cols = csr_array.shape - - record = Record() - for row_idx in range(n_rows): - record.Clear() - row = csr_array.getrow(row_idx) - # Write values - _write_feature_tensor(resolved_type, record, row.data) - # Write keys - _write_keys_tensor(resolved_type, record, row.indices.astype(np.uint64)) - - # Write labels - if labels is not None: - _write_label_tensor(resolved_label_type, record, labels[row_idx]) - - # Write shape - _write_shape(resolved_type, record, n_cols) - - _write_recordio(file, record.SerializeToString()) - - -def read_records(file): - """Eagerly read a collection of amazon Record protobuf objects from file. - - Args: - file: - """ - records = [] - for record_data in read_recordio(file): - record = Record() - record.ParseFromString(record_data) - records.append(record) - return records - - -# MXNet requires recordio records have length in bytes that's a multiple of 4 -# This sets up padding bytes to append to the end of the record, for diferent -# amounts of padding required. -padding = {} -for amount in range(4): - if sys.version_info >= (3,): - padding[amount] = bytes([0x00 for _ in range(amount)]) - else: - padding[amount] = bytearray([0x00 for _ in range(amount)]) - -_kmagic = 0xCED7230A - - -def _write_recordio(f, data): - """Writes a single data point as a RecordIO record to the given file. - - Args: - f: - data: - """ - length = len(data) - f.write(struct.pack("I", _kmagic)) - f.write(struct.pack("I", length)) - pad = (((length + 3) >> 2) << 2) - length - f.write(data) - f.write(padding[pad]) - - -def read_recordio(f): - """Placeholder Docstring""" - while True: - try: - (read_kmagic,) = struct.unpack("I", f.read(4)) - except struct.error: - return - assert read_kmagic == _kmagic - (len_record,) = struct.unpack("I", f.read(4)) - pad = (((len_record + 3) >> 2) << 2) - len_record - yield f.read(len_record) - if pad: - f.read(pad) - - -def _resolve_type(dtype): - """Placeholder Docstring""" - if dtype == np.dtype(int): - return "Int32" - if dtype == np.dtype(float): - return "Float64" - if dtype == np.dtype("float32"): - return "Float32" - raise ValueError("Unsupported dtype {} on array".format(dtype)) - - -numpy_to_record_serializer = deprecated_class(RecordSerializer, "numpy_to_record_serializer") -record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer") +# these imports ensure backward compatibility. 
+from sagemaker.deserializers import RecordDeserializer # noqa: F401 # pylint: disable=W0611 +from sagemaker.serializers import RecordSerializer # noqa: F401 # pylint: disable=W0611 +from sagemaker.serializer_utils import ( # noqa: F401 # pylint: disable=W0611 + read_recordio, + read_records, + write_numpy_to_dense_tensor, + write_spmatrix_to_sparse_tensor, + _write_recordio, +) diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index e651ee1460..1149cd02b2 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin, ge +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -114,7 +115,7 @@ def __init__( factors_init_scale: Optional[float] = None, factors_init_sigma: Optional[float] = None, factors_init_value: Optional[float] = None, - **kwargs + **kwargs, ): """Factorization Machines is :class:`Estimator` for general-purpose supervised learning. @@ -266,7 +267,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): self.role, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs + **kwargs, ) @@ -332,7 +333,7 @@ def __init__( model_data: Union[str, PipelineVariable], role: Optional[str] = None, sagemaker_session: Optional[Session] = None, - **kwargs + **kwargs, ): """Initialization for FactorizationMachinesModel class. @@ -365,5 +366,5 @@ def __init__( role, predictor_cls=FactorizationMachinesPredictor, sagemaker_session=sagemaker_session, - **kwargs + **kwargs, ) diff --git a/src/sagemaker/amazon/hyperparameter.py b/src/sagemaker/amazon/hyperparameter.py index 856927cb13..b479f8a271 100644 --- a/src/sagemaker/amazon/hyperparameter.py +++ b/src/sagemaker/amazon/hyperparameter.py @@ -28,7 +28,7 @@ def __init__(self, name, validate=lambda _: True, validation_message="", data_ty """Args: name (str): The name of this hyperparameter validate - (callable[object]->[bool]): A validation function or list of validation + (Callable[object]->[bool]): A validation function or list of validation functions. Each function validates an object and returns False if the object diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index 6306c38069..bc8e1b5d86 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -75,7 +75,7 @@ def __init__( random_negative_sampling_rate: Optional[int] = None, shuffled_negative_sampling_rate: Optional[int] = None, weight_decay: Optional[float] = None, - **kwargs + **kwargs, ): """This estimator is for IP Insights. 
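The import moves in this and the surrounding estimator hunks are behavior-preserving: the legacy `sagemaker.amazon.common` names are now re-exported from `sagemaker.serializers`, `sagemaker.deserializers`, and `sagemaker.serializer_utils`. A minimal sketch of what the shim guarantees and how the algorithm predictors keep using the pair (the endpoint name is hypothetical):

    # Old and new import paths resolve to the same classes.
    from sagemaker.amazon.common import RecordSerializer as LegacyRecordSerializer
    from sagemaker.deserializers import RecordDeserializer
    from sagemaker.serializers import RecordSerializer

    assert LegacyRecordSerializer is RecordSerializer

    # The predictors in these hunks pair the two, equivalent to:
    from sagemaker.predictor import Predictor

    predictor = Predictor(
        endpoint_name="my-endpoint",  # hypothetical endpoint
        serializer=RecordSerializer(),
        deserializer=RecordDeserializer(),
    )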
@@ -168,7 +168,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): self.role, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs + **kwargs, ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): @@ -209,7 +209,7 @@ def __init__( chain. serializer (sagemaker.serializers.BaseSerializer): Optional. Default serializes input data to text/csv. - deserializer (callable): Optional. Default parses JSON responses + deserializer (Callable): Optional. Default parses JSON responses using ``json.load(...)``. component_name (str): Optional. Name of the Amazon SageMaker inference component corresponding the predictor. @@ -235,7 +235,7 @@ def __init__( model_data: Union[str, PipelineVariable], role: Optional[str] = None, sagemaker_session: Optional[Session] = None, - **kwargs + **kwargs, ): """Creates object to get insights on S3 model data. @@ -268,5 +268,5 @@ def __init__( role, predictor_cls=IPInsightsPredictor, sagemaker_session=sagemaker_session, - **kwargs + **kwargs, ) diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index b52a042766..25abb9cb27 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin, ge, le +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -75,7 +76,7 @@ def __init__( epochs: Optional[int] = None, center_factor: Optional[int] = None, eval_metrics: Optional[List[Union[str, PipelineVariable]]] = None, - **kwargs + **kwargs, ): """A k-means clustering class :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`. @@ -184,7 +185,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): self.role, self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs + **kwargs, ) def _prepare_for_training(self, records, mini_batch_size=5000, job_name=None): @@ -261,7 +262,7 @@ def __init__( model_data: Union[str, PipelineVariable], role: Optional[str] = None, sagemaker_session: Optional[Session] = None, - **kwargs + **kwargs, ): """Initialization for KMeansModel class. 
@@ -294,5 +295,5 @@ def __init__( role, predictor_cls=KMeansPredictor, sagemaker_session=sagemaker_session, - **kwargs + **kwargs, ) diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index a57070dfd0..89ec979e09 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, isin +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -85,7 +86,7 @@ def __init__( index_metric: Optional[str] = None, faiss_index_ivf_nlists: Optional[str] = None, faiss_index_pq_m: Optional[int] = None, - **kwargs + **kwargs, ): """k-nearest neighbors (KNN) is :class:`Estimator` used for classification and regression. @@ -181,7 +182,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): self.role, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs + **kwargs, ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): @@ -252,7 +253,7 @@ def __init__( model_data: Union[str, PipelineVariable], role: Optional[str] = None, sagemaker_session: Optional[Session] = None, - **kwargs + **kwargs, ): """Function to initialize KNNModel. @@ -285,5 +286,5 @@ def __init__( role, predictor_cls=KNNPredictor, sagemaker_session=sagemaker_session, - **kwargs + **kwargs, ) diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index d47b6ecad8..c57da9643e 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -18,11 +18,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer +from sagemaker.deserializers import RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -59,7 +60,7 @@ def __init__( max_restarts: Optional[int] = None, max_iterations: Optional[int] = None, tol: Optional[float] = None, - **kwargs + **kwargs, ): """Latent Dirichlet Allocation (LDA) is :class:`Estimator` used for unsupervised learning. @@ -159,7 +160,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): self.role, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs + **kwargs, ) def _prepare_for_training( # pylint: disable=signature-differs @@ -236,7 +237,7 @@ def __init__( model_data: Union[str, PipelineVariable], role: Optional[str] = None, sagemaker_session: Optional[Session] = None, - **kwargs + **kwargs, ): """Initialization for LDAModel class. 
@@ -269,5 +270,5 @@ def __init__( role, predictor_cls=LDAPredictor, sagemaker_session=sagemaker_session, - **kwargs + **kwargs, ) diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index 231f0ba344..4533dcdaea 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -18,11 +18,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer +from sagemaker.deserializers import RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import isin, gt, lt, ge, le from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -190,7 +191,7 @@ def __init__( accuracy_top_k: Optional[int] = None, f_beta: Optional[float] = None, balance_multiclass_weights: Optional[bool] = None, - **kwargs + **kwargs, ): """An :class:`Estimator` for binary classification and regression. @@ -420,7 +421,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): self.role, self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs + **kwargs, ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): @@ -505,7 +506,7 @@ def __init__( model_data: Union[str, PipelineVariable], role: Optional[str] = None, sagemaker_session: Optional[Session] = None, - **kwargs + **kwargs, ): """Initialization for LinearLearnerModel. @@ -538,5 +539,5 @@ def __init__( role, predictor_cls=LinearLearnerPredictor, sagemaker_session=sagemaker_session, - **kwargs + **kwargs, ) diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index ddc5b95eb2..41dde1c33c 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, le, isin +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -89,7 +90,7 @@ def __init__( clip_gradient: Optional[float] = None, weight_decay: Optional[float] = None, learning_rate: Optional[float] = None, - **kwargs + **kwargs, ): """Neural Topic Model (NTM) is :class:`Estimator` used for unsupervised learning. @@ -194,7 +195,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): self.role, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs + **kwargs, ) def _prepare_for_training( # pylint: disable=signature-differs @@ -269,7 +270,7 @@ def __init__( model_data: Union[str, PipelineVariable], role: Optional[str] = None, sagemaker_session: Optional[Session] = None, - **kwargs + **kwargs, ): """Initialization for NTMModel class. 
@@ -302,5 +303,5 @@ def __init__( role, predictor_cls=NTMPredictor, sagemaker_session=sagemaker_session, - **kwargs + **kwargs, ) diff --git a/src/sagemaker/amazon/object2vec.py b/src/sagemaker/amazon/object2vec.py index 8a967484ec..536fda0229 100644 --- a/src/sagemaker/amazon/object2vec.py +++ b/src/sagemaker/amazon/object2vec.py @@ -189,7 +189,7 @@ def __init__( enc1_layers: Optional[int] = None, enc0_freeze_pretrained_embedding: Optional[bool] = None, enc1_freeze_pretrained_embedding: Optional[bool] = None, - **kwargs + **kwargs, ): """Object2Vec is :class:`Estimator` used for anomaly detection. @@ -338,7 +338,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): self.role, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs + **kwargs, ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): @@ -363,7 +363,7 @@ def __init__( model_data: Union[str, PipelineVariable], role: Optional[str] = None, sagemaker_session: Optional[Session] = None, - **kwargs + **kwargs, ): """Initialization for Object2VecModel class. @@ -396,5 +396,5 @@ def __init__( role, predictor_cls=Predictor, sagemaker_session=sagemaker_session, - **kwargs + **kwargs, ) diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index 6e694211dd..b724435afa 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import gt, isin +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -67,7 +68,7 @@ def __init__( algorithm_mode: Optional[str] = None, subtract_mean: Optional[bool] = None, extra_components: Optional[int] = None, - **kwargs + **kwargs, ): """A Principal Components Analysis (PCA) @@ -155,7 +156,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): self.role, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs + **kwargs, ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): @@ -249,7 +250,7 @@ def __init__( model_data: Union[str, PipelineVariable], role: Optional[str] = None, sagemaker_session: Optional[Session] = None, - **kwargs + **kwargs, ): """Initialization for PCAModel. 
@@ -282,5 +283,5 @@ def __init__( role, predictor_cls=PCAPredictor, sagemaker_session=sagemaker_session, - **kwargs + **kwargs, ) diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index d1b3a4b9f7..d60d5a7741 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -17,11 +17,12 @@ from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase -from sagemaker.amazon.common import RecordSerializer, RecordDeserializer from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.amazon.validation import ge, le +from sagemaker.deserializers import RecordDeserializer from sagemaker.predictor import Predictor from sagemaker.model import Model +from sagemaker.serializers import RecordSerializer from sagemaker.session import Session from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT @@ -60,7 +61,7 @@ def __init__( num_samples_per_tree: Optional[int] = None, num_trees: Optional[int] = None, eval_metrics: Optional[List] = None, - **kwargs + **kwargs, ): """An `Estimator` class implementing a Random Cut Forest. @@ -144,7 +145,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs): self.role, sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), - **kwargs + **kwargs, ) def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): @@ -222,7 +223,7 @@ def __init__( model_data: Union[str, PipelineVariable], role: Optional[str] = None, sagemaker_session: Optional[Session] = None, - **kwargs + **kwargs, ): """Initialization for RandomCutForestModel class. @@ -255,5 +256,5 @@ def __init__( role, predictor_cls=RandomCutForestPredictor, sagemaker_session=sagemaker_session, - **kwargs + **kwargs, ) diff --git a/src/sagemaker/amtviz/__init__.py b/src/sagemaker/amtviz/__init__.py new file mode 100644 index 0000000000..8554b32c4a --- /dev/null +++ b/src/sagemaker/amtviz/__init__.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Amazon SageMaker Automatic Model Tuning Visualization module. + +This module provides visualization capabilities for SageMaker hyperparameter tuning jobs. +It enables users to create interactive visualizations to analyze and understand the +performance of hyperparameter optimization experiments. + +Example: + >>> from sagemaker.amtviz import visualize_tuning_job + >>> visualize_tuning_job('my-tuning-job') +""" +from __future__ import absolute_import + +from sagemaker.amtviz.visualization import visualize_tuning_job + +__all__ = ["visualize_tuning_job"] diff --git a/src/sagemaker/amtviz/job_metrics.py b/src/sagemaker/amtviz/job_metrics.py new file mode 100644 index 0000000000..b99886941f --- /dev/null +++ b/src/sagemaker/amtviz/job_metrics.py @@ -0,0 +1,180 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Helper functions to retrieve job metrics from CloudWatch."""
+from __future__ import absolute_import
+
+from datetime import datetime, timedelta
+from typing import Callable, List, Optional, Tuple, Dict, Any
+import hashlib
+import os
+from pathlib import Path
+
+import logging
+import pandas as pd
+import numpy as np
+import boto3
+
+logger = logging.getLogger(__name__)
+
+cw = boto3.client("cloudwatch")
+sm = boto3.client("sagemaker")
+
+
+def disk_cache(outer: Callable) -> Callable:
+    """A decorator that implements disk-based caching for CloudWatch metrics data.
+
+    This decorator caches the output of the wrapped function to disk in JSON Lines format.
+    It creates a cache key from an MD5 hash of the function arguments and stores the data
+    in the user's home directory under .amtviz/cw_metrics_cache/.
+
+    Args:
+        outer (Callable): The function to be wrapped. Must return a pandas DataFrame
+            containing CloudWatch metrics data.
+
+    Returns:
+        Callable: A wrapper function that implements the caching logic.
+    """
+
+    def inner(*args: Any, **kwargs: Any) -> pd.DataFrame:
+        key_input = str(args) + str(kwargs)
+        # nosec b303 - Not used for cryptography, but to create a lookup key
+        key = hashlib.md5(key_input.encode("utf-8")).hexdigest()
+        cache_dir = Path.home().joinpath(".amtviz/cw_metrics_cache")
+        fn = f"{cache_dir}/req_{key}.jsonl.gz"
+        if Path(fn).exists():
+            try:
+                df = pd.read_json(fn, lines=True)
+                logger.debug("H")  # cache hit
+                df["ts"] = pd.to_datetime(df["ts"])
+                df["ts"] = df["ts"].dt.tz_localize(None)
+                # pyright: ignore [reportIndexIssue, reportOptionalSubscript]
+                df["rel_ts"] = pd.to_datetime(df["rel_ts"])
+                df["rel_ts"] = df["rel_ts"].dt.tz_localize(None)
+                return df
+            except KeyError:
+                # An empty file leads to an empty df, hence no df['ts'] possible
+                pass
+            # nosec b110 - doesn't matter why we could not load it.
+            except BaseException as e:
+                logger.error("\nException: %s - %s", type(e), e)
+
+        logger.debug("M")  # cache miss
+        df = outer(*args, **kwargs)
+        assert isinstance(df, pd.DataFrame), "Only caching Pandas DataFrames."
+
+        os.makedirs(cache_dir, exist_ok=True)
+        df.to_json(fn, orient="records", date_format="iso", lines=True)
+
+        return df
+
+    return inner
+
+
+def _metric_data_query_tpl(metric_name: str, dim_name: str, dim_value: str) -> Dict[str, Any]:
+    """Returns a CloudWatch metric data query template."""
+    return {
+        "Id": metric_name.lower().replace(":", "_").replace("-", "_"),
+        "MetricStat": {
+            "Stat": "Average",
+            "Metric": {
+                "Namespace": "/aws/sagemaker/TrainingJobs",
+                "MetricName": metric_name,
+                "Dimensions": [
+                    {"Name": dim_name, "Value": dim_value},
+                ],
+            },
+            "Period": 60,
+        },
+        "ReturnData": True,
+    }
+
+
+def _get_metric_data(
+    queries: List[Dict[str, Any]], start_time: datetime, end_time: datetime
+) -> pd.DataFrame:
+    """Fetches CloudWatch metrics between timestamps, returns a DataFrame with selected columns."""
+    start_time = start_time - timedelta(hours=1)
+    end_time = end_time + timedelta(hours=1)
+    response = cw.get_metric_data(MetricDataQueries=queries, StartTime=start_time, EndTime=end_time)
+
+    df = pd.DataFrame()
+    if "MetricDataResults" not in response:
+        return df
+
+    for metric_data in response["MetricDataResults"]:
+        values = metric_data["Values"]
+        ts = np.array(metric_data["Timestamps"], dtype=np.datetime64)
+        labels = [metric_data["Label"]] * len(values)
+
+        df = pd.concat([df, pd.DataFrame({"value": values, "ts": ts, "label": labels})])
+
+    # We now calculate the relative time based on the first actually observed
+    # timestamps, not the potentially earlier start time that we used to scope
+    # our CW API call. The difference could be, for example, startup times or
+    # waiting for Spot capacity.
+    if not df.empty:
+        df["rel_ts"] = datetime.fromtimestamp(1) + (df["ts"] - df["ts"].min())  # pyright: ignore
+    return df
+
+
+@disk_cache
+def _collect_metrics(
+    dimensions: List[Tuple[str, str]], start_time: datetime, end_time: Optional[datetime]
+) -> pd.DataFrame:
+    """Collects SageMaker training job metrics from CloudWatch for dimensions and time range."""
+    df = pd.DataFrame()
+    for dim_name, dim_value in dimensions:
+        response = cw.list_metrics(
+            Namespace="/aws/sagemaker/TrainingJobs",
+            Dimensions=[
+                {"Name": dim_name, "Value": dim_value},
+            ],
+        )
+        if not response["Metrics"]:
+            continue
+        metric_names = [metric["MetricName"] for metric in response["Metrics"]]
+        if not metric_names:
+            # No metric data yet, or not any longer, because the data were aged out
+            continue
+        metric_data_queries = [
+            _metric_data_query_tpl(metric_name, dim_name, dim_value) for metric_name in metric_names
+        ]
+        df = pd.concat([df, _get_metric_data(metric_data_queries, start_time, end_time)])
+
+    return df
+
+
+def get_cw_job_metrics(
+    job_name: str, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None
+) -> pd.DataFrame:
+    """Retrieves CloudWatch metrics for a SageMaker training job.
+
+    Args:
+        job_name (str): Name of the SageMaker training job.
+        start_time (datetime, optional): Start time for metrics collection.
+            Defaults to now - 4 hours.
+        end_time (datetime, optional): End time for metrics collection.
+            Defaults to start_time + 4 hours.
+
+    Returns:
+        pd.DataFrame: Metrics data with columns for value, timestamp, and metric name.
+            Results are cached to disk for improved performance.
+ """ + dimensions = [ + ("TrainingJobName", job_name), + ("Host", job_name + "/algo-1"), + ] + # If not given, use reasonable defaults for start and end time + start_time = start_time or datetime.now() - timedelta(hours=4) + end_time = end_time or start_time + timedelta(hours=4) + return _collect_metrics(dimensions, start_time, end_time) diff --git a/src/sagemaker/amtviz/visualization.py b/src/sagemaker/amtviz/visualization.py new file mode 100644 index 0000000000..7f09117d1e --- /dev/null +++ b/src/sagemaker/amtviz/visualization.py @@ -0,0 +1,857 @@ +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides visualization capabilities for SageMaker hyperparameter tuning jobs. + +It contains utilities to create interactive visualizations of hyperparameter tuning results +using Altair charts. The module enables users to analyze and understand the performance +of their hyperparameter optimization experiments through various visual representations +including: +- Progress of objective metrics over time +- Distribution of results +- Relationship between hyperparameters and objective values +- Training job metrics and instance utilization +- Comparative analysis across multiple tuning jobs + +Main Features: + - Visualize single or multiple hyperparameter tuning jobs + - Display training job metrics from CloudWatch + - Support for both completed and in-progress tuning jobs + - Interactive filtering and highlighting of data points + - CPU, memory, and GPU utilization visualization + - Advanced visualization options for detailed analysis + +Primary Classes and Functions: + - visualize_tuning_job: Main function to create visualizations for tuning jobs + - create_charts: Core chart creation functionality + - get_job_analytics_data: Retrieves and processes tuning job data + +Dependencies: + - altair: For creating interactive visualizations + - pandas: For data manipulation and analysis + - boto3: For AWS service interaction + - sagemaker: For accessing SageMaker resources +""" +from __future__ import absolute_import + +from typing import Union, List, Optional, Tuple +import os +import warnings +import logging +import altair as alt +import pandas as pd +import numpy as np +import boto3 +import sagemaker +from sagemaker.amtviz.job_metrics import get_cw_job_metrics + +warnings.filterwarnings("ignore") +logger = logging.getLogger(__name__) + +pd.set_option("display.max_rows", 500) +pd.set_option("display.max_columns", 500) +pd.set_option("display.width", 1000) +pd.set_option("display.max_colwidth", None) # Don't truncate TrainingJobName + + +alt.data_transformers.disable_max_rows() +altair_renderer = os.getenv("ALTAIR_RENDERER", "default") +logger.info("Setting altair renderer to %s.", altair_renderer) +alt.renderers.enable(altair_renderer) + + +sm = boto3.client("sagemaker") + + +def _columnize(charts: List[alt.Chart], cols: int = 2) -> alt.VConcatChart: + """Arrange charts in columns.""" + return alt.vconcat(*[alt.hconcat(*charts[i : i + cols]) for i in range(0, len(charts), cols)]) + + +def visualize_tuning_job( + tuning_jobs: Union[str, List[str], 
"sagemaker.tuner.HyperparameterTuner"], + return_dfs: bool = False, + job_metrics: Optional[List[str]] = None, + trials_only: bool = False, + advanced: bool = False, +) -> Union[alt.Chart, Tuple[alt.Chart, pd.DataFrame, pd.DataFrame]]: + """Visualize SageMaker hyperparameter tuning jobs. + + Args: + tuning_jobs: Single tuning job or list of tuning jobs (name or HyperparameterTuner object) + return_dfs: Whether to return the underlying DataFrames + job_metrics: List of additional job metrics to include + trials_only: Whether to only show trials data + advanced: Whether to show advanced visualizations + + Returns: + If return_dfs is False, returns Altair chart + If return_dfs is True, returns tuple of (chart, trials_df, full_df) + """ + + trials_df, tuned_parameters, objective_name, is_minimize = get_job_analytics_data(tuning_jobs) + + try: + from IPython import get_ipython, display + + if get_ipython(): + # Running in a Jupyter Notebook + display(trials_df.head(10)) + else: + # Running in a non-Jupyter environment + logger.info(trials_df.head(10).to_string()) + except ImportError: + # Not running in a Jupyter Notebook + logger.info(trials_df.head(10).to_string()) + + full_df = _prepare_consolidated_df(trials_df) if not trials_only else pd.DataFrame() + + trials_df.columns = trials_df.columns.map(_clean_parameter_name) + full_df.columns = full_df.columns.map(_clean_parameter_name) + tuned_parameters = [_clean_parameter_name(tp) for tp in tuned_parameters] + objective_name = _clean_parameter_name(objective_name) + + charts = create_charts( + trials_df, + tuned_parameters, + full_df, + objective_name, + minimize_objective=is_minimize, + job_metrics=job_metrics, + advanced=advanced, + ) + + if return_dfs: + return charts, trials_df, full_df + return charts + + +def create_charts( + trials_df: pd.DataFrame, + tuning_parameters: List[str], + full_df: pd.DataFrame, + objective_name: str, + minimize_objective: bool, + job_metrics: Optional[List[str]] = None, + highlight_trials: bool = True, + color_trials: bool = False, + advanced: bool = False, +) -> alt.Chart: + """Create visualization charts for hyperparameter tuning results. 
+
+    Args:
+        trials_df: DataFrame containing trials data
+        tuning_parameters: List of hyperparameter names
+        full_df: DataFrame with consolidated data
+        objective_name: Name of the objective metric
+        minimize_objective: Whether objective should be minimized
+        job_metrics: Additional job metrics to include
+        highlight_trials: Whether to highlight selected trials
+        color_trials: Whether to color trials by job
+        advanced: Whether to show advanced visualizations
+
+    Returns:
+        Altair chart visualization (an empty DataFrame if there are no trials yet)
+    """
+
+    if trials_df.empty:
+        logger.info("No results available yet.")
+        return pd.DataFrame()
+
+    if job_metrics is None:
+        job_metrics = []
+
+    multiple_tuning_jobs = len(trials_df["TuningJobName"].unique()) > 1
+    multiple_job_status = len(trials_df["TrainingJobStatus"].unique()) > 1
+
+    # Rows, n>1
+    # Detail Charts
+
+    brush = alt.selection_interval(encodings=["x"], resolve="intersect", empty=True)
+
+    job_highlight_selection = alt.selection_point(
+        on="mouseover",
+        nearest=False,
+        empty=False,
+        fields=["TrainingJobName", "TrainingStartTime"],
+    )
+
+    # create tooltip
+    detail_tooltip = []
+    for trp in [objective_name] + tuning_parameters:
+        if trials_df[trp].dtype == np.float64:
+            trp = alt.Tooltip(trp, format=".2e")
+        detail_tooltip.append(trp)
+
+    detail_tooltip.append(alt.Tooltip("TrainingStartTime:T", format="%H:%M:%S"))
+    detail_tooltip.extend(["TrainingJobName", "TrainingJobStatus", "TrainingElapsedTimeSeconds"])
+
+    # create stroke/stroke-width for tuning_jobs
+    # and color for training jobs, if wanted
+    # add coloring of the stroke to highlight correlated
+    # data points
+    jobs_props = {"shape": alt.Shape("TrainingJobStatus:N", legend=None)}
+
+    if multiple_tuning_jobs:
+        jobs_props["strokeWidth"] = alt.StrokeWidthValue(2.0)
+        jobs_props["stroke"] = alt.Stroke("TuningJobName:N", legend=None)
+
+    if color_trials:
+        jobs_props["color"] = alt.Color("TrainingJobName:N")
+
+    if highlight_trials:
+        jobs_props["strokeWidth"] = alt.condition(
+            job_highlight_selection,
+            alt.StrokeWidthValue(2.0),
+            alt.StrokeWidthValue(2.0),
+        )
+        jobs_props["stroke"] = alt.condition(
+            job_highlight_selection,
+            alt.StrokeValue("gold"),
+            (
+                alt.Stroke("TuningJobName:N", legend=None)
+                if multiple_tuning_jobs
+                else alt.StrokeValue("white")
+            ),
+        )
+
+    opacity = alt.condition(brush, alt.value(1.0), alt.value(0.35))
+    charts = []
+
+    # Min and max of the objective. This is used in filtered
+    # charts, so that the filtering does not make the axis
+    # jump, which would make comparisons harder.
+    objective_scale = alt.Scale(
+        domain=(
+            trials_df[objective_name].min(),
+            trials_df[objective_name].max(),
+        )
+    )
+
+    # If we have multiple tuning jobs, we also want to be able
+    # to discriminate based on the individual tuning job, so
+    # we just treat them as an additional tuning parameter
+    tuning_job_param = ["TuningJobName"] if multiple_tuning_jobs else []
+    tuning_parameters = tuning_parameters.copy() + tuning_job_param
+
+    # If we use early stopping and at least some jobs were
+    # stopped early, we want to be able to discriminate
+    # those jobs.
+    if multiple_job_status:
+        tuning_parameters.append("TrainingJobStatus")
+
+    def render_detail_charts():
+        # To force a tuning job to sample a combination more than once, we
+        # sometimes introduce a hyperparameter that has no effect.
+        # Its values are random and without impact, so we omit it from analysis.
+ ignored_parameters = {"dummy"} + for tuning_parameter in tuning_parameters: + if tuning_parameter in ignored_parameters: + continue + + # Map dataframe's dtype to altair's types and + # adjust scale if necessary + scale_type = "linear" + scale_log_base = 10 + + few_values = len(trials_df[tuning_parameter].unique()) < 8 + parameter_type = "N" # Nominal + dtype = str(trials_df.dtypes[tuning_parameter]) + if "float" in dtype: + parameter_type = "Q" # Quantitative + ratio = (trials_df[tuning_parameter].max() + 1e-10) / ( + trials_df[tuning_parameter].min() + 1e-10 + ) + not_likely_discrete = ( + len(trials_df[tuning_parameter].unique()) > trials_df[tuning_parameter].count() + ) # edge case when both are equal + if few_values and not_likely_discrete: + if ratio > 50: + scale_type = "log" + elif ratio > 10: + scale_type = "log" + scale_log_base = 2 + + elif "int" in dtype or "object" in dtype: + parameter_type = "O" # Ordinal + + x_encoding = alt.X( + f"{tuning_parameter}:{parameter_type}", + scale=alt.Scale( + zero=False, + padding=1, + type=scale_type, + base=scale_log_base, + ), + ) + + # Sync the coloring for categorical hyperparameters + discrete = parameter_type in ["O", "N"] and few_values + + # Detail Chart + charts.append( + alt.Chart(trials_df) + .add_params(brush) + .add_params(job_highlight_selection) + .mark_point(filled=True, size=50) + .encode( + x=x_encoding, + y=alt.Y( + f"{objective_name}:Q", + scale=alt.Scale(zero=False, padding=1), + axis=alt.Axis(title=objective_name), + ), + opacity=opacity, + tooltip=detail_tooltip, + **jobs_props, + ) + ) + + if discrete: + # Individually coloring the values only if we don't already + # use the colors to show the different tuning jobs + logger.info("%s, %s", parameter_type, tuning_parameter) + if not multiple_tuning_jobs: + charts[-1] = charts[-1].encode(color=f"{tuning_parameter}:N") + charts[-1] = ( + ( + charts[-1] + | alt.Chart(trials_df) + .transform_filter(brush) + .transform_density( + objective_name, + bandwidth=0.01, + groupby=[tuning_parameter], + # https://github.com/vega/altair/issues/3203#issuecomment-2141558911 + # Specifying extent no longer necessary (>5.1.2). 
+ extent=[ + trials_df[objective_name].min(), + trials_df[objective_name].max(), + ], + ) + .mark_area(opacity=0.5) + .encode( + x=alt.X( + "value:Q", + title=objective_name, + scale=objective_scale, + ), + y="density:Q", + color=alt.Color( + f"{tuning_parameter}:N", + ), + tooltip=tuning_parameter, + ) + ).properties(title=tuning_parameter) + # .resolve_scale("independent") + # .resolve_legend(color="independent") + ) + + if advanced and parameter_type == "Q": + # Adding tick marks to the detail charts with quantitative hyperparameters + x_enc = x_encoding.copy() + charts[-1].encoding.x.title = None + charts[-1].encoding.x.axis = alt.Axis(labels=False) + + charts[-1] = charts[-1] & alt.Chart(trials_df).mark_tick(opacity=0.5).encode( + x=x_enc, + opacity=alt.condition(brush, alt.value(0.5), alt.value(0.1)), + ) + + return _columnize(charts) + + detail_charts = render_detail_charts() + + # First Row + # Progress Over Time Chart + + def render_progress_chart(): + # Sorting trials by training start time, so that we can track the \ + # progress of the best objective so far over time + trials_df_by_tst = trials_df.sort_values(["TuningJobName", "TrainingStartTime"]) + trials_df_by_tst["cum_objective"] = trials_df_by_tst.groupby(["TuningJobName"]).transform( + lambda x: x.cummin() if minimize_objective else x.cummax() + )[objective_name] + + progress_chart = ( + alt.Chart(trials_df_by_tst) + .add_params(brush) + .add_params(job_highlight_selection) + .mark_point(filled=True, size=50) + .encode( + x=alt.X("TrainingStartTime:T", scale=alt.Scale(nice=True)), + y=alt.Y( + f"{objective_name}:Q", + scale=alt.Scale(zero=False, padding=1), + axis=alt.Axis(title=objective_name), + ), + opacity=opacity, + tooltip=detail_tooltip, + **jobs_props, + ) + ) + + cum_obj_chart = ( + alt.Chart(trials_df_by_tst) + .mark_line( + interpolate="step-after", + opacity=1.0, + strokeDash=[3, 3], + strokeWidth=2.0, + ) + .encode( + x=alt.X("TrainingStartTime:T", scale=alt.Scale(nice=True)), + y=alt.Y("cum_objective:Q", scale=alt.Scale(zero=False, padding=1)), + stroke=alt.Stroke("TuningJobName:N", legend=None), + ) + ) + + if advanced: + return cum_obj_chart + progress_chart + return progress_chart + + progress_chart = render_progress_chart() + + # First Row + # KDE Training Objective + result_hist_chart = ( + alt.Chart(trials_df) + .transform_filter(brush) + .transform_density(objective_name, bandwidth=0.01) + .mark_area() + .encode( + x=alt.X("value:Q", scale=objective_scale, title=objective_name), + y="density:Q", + ) + ) + # Training Jobs + training_jobs_chart = ( + alt.Chart(trials_df.sort_values(objective_name), title="Training Jobs") + .mark_bar() + .add_params(brush) + .add_params(job_highlight_selection) + .encode( + y=alt.Y(f"{objective_name}:Q"), + x=alt.X("TrainingJobName:N", sort=None), + color=alt.Color("TrainingJobName:N"), + opacity=opacity, + **jobs_props, + ) + ) + + # Job Level Stats + + training_job_name_encodings = { + "color": alt.condition( + brush, + alt.Color("TrainingJobName:N", legend=None), + alt.value("grey"), + ), + "opacity": alt.condition(brush, alt.value(1.0), alt.value(0.3)), + "strokeWidth": alt.condition(brush, alt.value(2.5), alt.value(0.8)), + } + + duration_format = "%M:%S" + metrics_tooltip = [ + "TrainingJobName:N", + "value:Q", + "label:N", + alt.Tooltip("ts:T", format="%e:%H:%M"), + alt.Tooltip("rel_ts:T", format="%e:%H:%M"), + ] + + job_level_rows = alt.HConcatChart() + + # Use CW metrics + if not full_df.empty: + # Objective Progression + + objective_progression_chart = None + # 
Suppress diagram if we only have one, final, value + if ( + full_df.loc[full_df.label == objective_name] + .groupby(["TuningJobName", "TrainingJobName"])[objective_name] + .count() + .max() + > 1 + ): + objective_progression_chart = ( + alt.Chart(full_df, title=f"Progression {objective_name}", width=400) + .transform_filter(alt.FieldEqualPredicate(field="label", equal=objective_name)) + .mark_line(point=True) + .encode( + x=alt.X("rel_ts:T", axis=alt.Axis(format=duration_format)), + y=alt.Y("value:Q", scale=alt.Scale(zero=False)), + **training_job_name_encodings, + tooltip=metrics_tooltip, + ) + .interactive() + ) + + if multiple_job_status: + objective_progression_chart = objective_progression_chart.encode( + strokeDash=alt.StrokeDash("TrainingJobStatus:N", legend=None) + ) + + # Secondary chart showing the same contents, but by absolute time. + objective_progression_absolute_chart = objective_progression_chart.encode( + x=alt.X("ts:T", scale=alt.Scale(nice=True)) + ) + + objective_progression_chart = ( + objective_progression_chart | objective_progression_absolute_chart + ) + + ### + + job_metrics_charts = [] + for metric in job_metrics: + metric_chart = ( + alt.Chart(full_df, title=metric, width=400) + .transform_filter(alt.FieldEqualPredicate(field="label", equal=metric)) + .encode( + y=alt.Y("value:Q", scale=alt.Scale(zero=False)), + **training_job_name_encodings, + tooltip=metrics_tooltip, + ) + .interactive() + ) + + if ( + full_df.loc[full_df.label == metric] + .groupby(["TuningJobName", "TrainingJobName"]) + .count() + .value.max() + == 1 + ): + # single value, render as a bar over the training jobs on the x-axis + metric_chart = metric_chart.encode( + x=alt.X("TrainingJobName:N", sort=None) + ).mark_bar(interpolate="linear", point=True) + else: + # multiple values, render the values over time on the x-axis + metric_chart = metric_chart.encode( + x=alt.X("rel_ts:T", axis=alt.Axis(format=duration_format)) + ).mark_line(interpolate="linear", point=True) + + job_metrics_charts.append(metric_chart) + + job_metrics_chart = _columnize(job_metrics_charts, 3) + + # Job instance + # 'MemoryUtilization', 'CPUUtilization' + instance_metrics_chart = ( + alt.Chart(full_df, title="CPU and Memory") + .transform_filter( + alt.FieldOneOfPredicate( + field="label", + oneOf=[ + "MemoryUtilization", + "CPUUtilization", + ], + ) + ) + .mark_line() + .encode( + x=alt.X("rel_ts:T", axis=alt.Axis(format=duration_format)), + y="value:Q", + **training_job_name_encodings, + strokeDash=alt.StrokeDash("label:N", legend=alt.Legend(orient="bottom")), + tooltip=metrics_tooltip, + ) + .interactive() + ) + + if "GPUUtilization" in full_df.label.values: + instance_metrics_chart = ( + instance_metrics_chart + | alt.Chart(full_df, title="GPU and GPU Memory") + .transform_filter( + alt.FieldOneOfPredicate( + field="label", + oneOf=[ + "GPUMemoryUtilization", + "GPUUtilization", + ], + ) + ) + .mark_line() + .encode( + x=alt.X("rel_ts:T", axis=alt.Axis(format=duration_format)), + y=alt.Y("value:Q"), + **training_job_name_encodings, + strokeDash=alt.StrokeDash("label:N", legend=alt.Legend(orient="bottom")), + tooltip=metrics_tooltip, + ) + .interactive() + ) + + job_level_rows = job_metrics_chart & instance_metrics_chart + if objective_progression_chart: + job_level_rows = objective_progression_chart & job_level_rows + job_level_rows = job_level_rows.resolve_scale(strokeDash="independent").properties( + title="Job / Instance Level Metrics" + ) + + overview_row = (progress_chart | result_hist_chart).properties( + 
title="Hyper Parameter Tuning Job" + ) + detail_rows = detail_charts.properties(title="Hyper Parameter Details") + if job_level_rows: + job_level_rows = training_jobs_chart & job_level_rows + + return overview_row & detail_rows & job_level_rows + + +def _clean_parameter_name(s): + """Helper method to ensure proper parameter name characters for altair 5+""" + return s.replace(":", "_").replace(".", "_") + + +def _prepare_training_job_metrics(jobs): + """Fetches and combines CloudWatch metrics for multiple training jobs. + + Args: + jobs (list): List of (job_name, start_time, end_time) tuples. + + Returns: + pandas.DataFrame: Combined metrics DataFrame with 'TrainingJobName' column. + """ + df = pd.DataFrame() + for job_name, start_time, end_time in jobs: + job_df = get_cw_job_metrics( + job_name, + start_time=pd.Timestamp(start_time) - pd.DateOffset(hours=8), + end_time=pd.Timestamp(end_time) + pd.DateOffset(hours=8), + ) + if job_df is None: + logger.info("No CloudWatch metrics for %s. Skipping.", job_name) + continue + + job_df["TrainingJobName"] = job_name + df = pd.concat([df, job_df]) + return df + + +def _prepare_consolidated_df(trials_df): + """Merges training job metrics with trials data into a consolidated DataFrame.""" + if trials_df.empty: + return pd.DataFrame() + + logger.debug("Cache Hit/Miss: ", end="") + jobs_df = _prepare_training_job_metrics( + zip( + trials_df.TrainingJobName.values, + trials_df.TrainingStartTime.values, + trials_df.TrainingEndTime.values, + ) + ) + logger.info("") + + if jobs_df.empty: + return pd.DataFrame() + + merged_df = pd.merge(jobs_df, trials_df, on="TrainingJobName") + return merged_df + + +def _get_df(tuning_job_name, filter_out_stopped=False): + """Retrieves hyperparameter tuning job results and returns preprocessed DataFrame. + + Returns a DataFrame containing tuning metrics and parameters for the specified job. + """ + + tuner = sagemaker.HyperparameterTuningJobAnalytics(tuning_job_name) + + df = tuner.dataframe() + if df.empty: # HPO job just started; no results yet + return df + + df["TuningJobName"] = tuning_job_name + + # Filter out jobs without FinalObjectiveValue + df = df[df["FinalObjectiveValue"] > -float("inf")] + + # Jobs early stopped by AMT are reported with their last + # objective value, before they are stopped. + # However this value may not be a good representation + # of the eventual objective value we would have seen + # if run without stopping. Therefore it may be confusing + # to include those runs. + # For now, if included, we use a different mark to + # discriminate visually between a stopped and finished job + + if filter_out_stopped: + df = df[df["TrainingJobStatus"] != "Stopped"] + + # Preprocessing values for [32], [64] etc. + for tuning_range in tuner.tuning_ranges.values(): + parameter_name = tuning_range["Name"] + if df.dtypes[parameter_name] == "O": + try: + # Remove decorations, like [] + df[parameter_name] = df[parameter_name].apply( + lambda v: v.replace("[", "").replace("]", "").replace('"', "") + ) + + # Is it an int? 3 would work, 3.4 would fail. + try: + df[parameter_name] = df[parameter_name].astype(int) + except ValueError: + # A float then? 
+ df[parameter_name] = df[parameter_name].astype(float) + + except (ValueError, TypeError, AttributeError): + # Catch exceptions that might occur during string manipulation or type conversion + # - ValueError: Could not convert string to float/int + # - TypeError: Object doesn't support the operation + # - AttributeError: Object doesn't have replace method + # Leaving the value untouched + pass + + return df + + +def _get_tuning_job_names_with_parents(tuning_job_names): + """Resolve dependent jobs, one level only""" + + all_tuning_job_names = [] + for tuning_job_name in tuning_job_names: + tuning_job_result = sm.describe_hyper_parameter_tuning_job( + HyperParameterTuningJobName=tuning_job_name + ) + + # find parent jobs and retrieve all tuner dataframes + parent_jobs = [] + if "WarmStartConfig" in tuning_job_result: + parent_jobs = [ + cfg["HyperParameterTuningJobName"] + for cfg in tuning_job_result["WarmStartConfig"]["ParentHyperParameterTuningJobs"] + ] + if parent_jobs: + logger.info("Tuning job %s's parents: %s", tuning_job_name, ", ".join(parent_jobs)) + all_tuning_job_names.extend([tuning_job_name, *parent_jobs]) + + # return de-duplicated tuning job names + return list(set(all_tuning_job_names)) + + +def get_job_analytics_data(tuning_job_names): + """Retrieves and processes analytics data from hyperparameter tuning jobs. + + Args: + tuning_job_names (str or list): Single tuning job name or list of names/tuner objects. + + Returns: + tuple: (DataFrame with training results, tuned params list, objective name, is_minimize). + + Raises: + ValueError: If tuning jobs have different objectives or optimization directions. + """ + if not isinstance(tuning_job_names, list): + tuning_job_names = [tuning_job_names] + + # Ensure to create a list of tuning job names (strings) + tuning_job_names = [ + ( + tuning_job.describe()["HyperParameterTuningJobName"] + if isinstance(tuning_job, sagemaker.tuner.HyperparameterTuner) + else tuning_job + ) + for tuning_job in tuning_job_names + ] + + # Maintain combined tuner dataframe from all tuning jobs + df = pd.DataFrame() + + # maintain objective, direction of optimization and tuned parameters + objective_name = None + is_minimize = None + tuned_parameters = None + + all_tuning_job_names = _get_tuning_job_names_with_parents(tuning_job_names) + + for tuning_job_name in all_tuning_job_names: + tuning_job_result = sm.describe_hyper_parameter_tuning_job( + HyperParameterTuningJobName=tuning_job_name + ) + status = tuning_job_result["HyperParameterTuningJobStatus"] + logger.info("Tuning job %-25s status: %s", tuning_job_name, status) + + df = pd.concat([df, _get_df(tuning_job_name)]) + + # maintain objective and assure that all tuning jobs use the same + job_is_minimize = ( + tuning_job_result["HyperParameterTuningJobConfig"]["HyperParameterTuningJobObjective"][ + "Type" + ] + != "Maximize" + ) + job_objective_name = tuning_job_result["HyperParameterTuningJobConfig"][ + "HyperParameterTuningJobObjective" + ]["MetricName"] + job_tuned_parameters = [ + v["Name"] + for v in sagemaker.HyperparameterTuningJobAnalytics( + tuning_job_name + ).tuning_ranges.values() + ] + + if not objective_name: + objective_name = job_objective_name + is_minimize = job_is_minimize + tuned_parameters = job_tuned_parameters + else: + if ( + objective_name != job_objective_name + or is_minimize != job_is_minimize + or set(tuned_parameters) != set(job_tuned_parameters) + ): + raise ValueError( + "All tuning jobs must use the same objective and optimization direction." 
+                )
+
+    if not df.empty:
+        # Cleanup wrongly encoded floats, e.g. containing quotes.
+        for i, dtype in enumerate(df.dtypes):
+            column_name = str(df.columns[i])
+            if column_name in [
+                "TrainingJobName",
+                "TrainingJobStatus",
+                "TuningJobName",
+            ]:
+                continue
+            if dtype == "object":
+                val = df[column_name].iloc[0]
+                if isinstance(val, str) and val.startswith('"'):
+                    try:
+                        df[column_name] = df[column_name].apply(lambda x: int(x.replace('"', "")))
+                    except (ValueError, TypeError, AttributeError):
+                        # noqa: E722 nosec b110 if we fail, we just continue with what we had
+                        pass  # Value is not an int, but a string
+
+        df = df.sort_values("FinalObjectiveValue", ascending=is_minimize)
+        df[objective_name] = df.pop("FinalObjectiveValue")
+
+        # Fix potential issue with dates represented as objects, instead of a timestamp
+        # This can in other cases lead to:
+        # https://www.markhneedham.com/blog/2020/01/10/altair-typeerror-object-type-
+        # date-not-json-serializable/
+        # Seen this for TrainingEndTime, but will watch TrainingStartTime as well now.
+        df["TrainingEndTime"] = pd.to_datetime(df["TrainingEndTime"])
+        df["TrainingStartTime"] = pd.to_datetime(df["TrainingStartTime"])
+
+        logger.info("")
+        logger.info("Number of training jobs with valid objective: %d", len(df))
+        logger.info("Lowest: %s Highest %s", min(df[objective_name]), max(df[objective_name]))
+
+        tuned_parameters = [_clean_parameter_name(tp) for tp in tuned_parameters]
+
+    return df, tuned_parameters, objective_name, is_minimize
diff --git a/src/sagemaker/apiutils/_base_types.py b/src/sagemaker/apiutils/_base_types.py
index acee3d4d67..eb1555f109 100644
--- a/src/sagemaker/apiutils/_base_types.py
+++ b/src/sagemaker/apiutils/_base_types.py
@@ -123,7 +123,7 @@ def _list(
         boto_list_items_name,
         boto_next_token_name="NextToken",
         sagemaker_session=None,
-        **kwargs
+        **kwargs,
     ):
         """List objects from the SageMaker API."""
         sagemaker_session = sagemaker_session or _utils.default_session()
@@ -154,7 +154,7 @@ def _search(
         search_item_factory,
         boto_next_token_name="NextToken",
         sagemaker_session=None,
-        **kwargs
+        **kwargs,
     ):
         """Search for objects with the SageMaker API."""
         sagemaker_session = sagemaker_session or _utils.default_session()
diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py
index bb4059c03a..e18d7ba2b9 100644
--- a/src/sagemaker/automl/automl.py
+++ b/src/sagemaker/automl/automl.py
@@ -478,7 +478,7 @@ def create_model(
                 training cluster for distributed training. Default: False
             model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file
                 if the model is repacked
-            predictor_cls (callable[string, sagemaker.session.Session]): A
+            predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
                 function to call to create a predictor (default: None). If
                 specified, ``deploy()`` returns the result of invoking this
                 function on the created endpoint name.
@@ -591,7 +591,7 @@ def deploy(
                 training cluster for distributed training. Default: False
             model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file
                 if the model is repacked
-            predictor_cls (callable[string, sagemaker.session.Session]): A
+            predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
                 function to call to create a predictor (default: None). If
                 specified, ``deploy()`` returns the result of invoking this
                 function on the created endpoint name.
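The docstring fixes above concern the predictor_cls hook. A hedged sketch of how a caller uses it (the AutoML instance and instance type are hypothetical; with predictor_cls set, deploy() returns its result instead of None):

    from sagemaker.predictor import Predictor

    # 'automl' is a previously constructed and fitted AutoML object (hypothetical).
    predictor = automl.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.xlarge",
        predictor_cls=Predictor,  # invoked on the created endpoint name
    )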
@@ -609,7 +609,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests Returns: - callable[string, sagemaker.session.Session] or ``None``: + Optional[Callable[[string, sagemaker.session.Session], Any]]: If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on the created endpoint name. Otherwise, ``None``. """ diff --git a/src/sagemaker/automl/automlv2.py b/src/sagemaker/automl/automlv2.py index 0819e5384e..b071be3b24 100644 --- a/src/sagemaker/automl/automlv2.py +++ b/src/sagemaker/automl/automlv2.py @@ -1022,7 +1022,7 @@ def create_model( training cluster for distributed training. Default: False model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -1130,7 +1130,7 @@ def deploy( training cluster for distributed training. Default: False model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -1148,7 +1148,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests Returns: - callable[string, sagemaker.session.Session] or ``None``: + Optional[Callable[[string, sagemaker.session.Session], Any]]: If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on the created endpoint name. Otherwise, ``None``. """ diff --git a/src/sagemaker/aws_batch/__init__.py b/src/sagemaker/aws_batch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/aws_batch/batch_api_helper.py b/src/sagemaker/aws_batch/batch_api_helper.py new file mode 100644 index 0000000000..4482a644ab --- /dev/null +++ b/src/sagemaker/aws_batch/batch_api_helper.py @@ -0,0 +1,186 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""The module provides helper function for Batch Submit/Describe/Terminal job APIs.""" +from __future__ import absolute_import + +import json +from typing import List, Dict, Optional +from sagemaker.aws_batch.constants import ( + SAGEMAKER_TRAINING, + DEFAULT_TIMEOUT, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, +) +from sagemaker.aws_batch.boto_client import get_batch_boto_client + + +def submit_service_job( + training_payload: Dict, + job_name: str, + job_queue: str, + retry_config: Optional[Dict] = None, + scheduling_priority: Optional[int] = None, + timeout: Optional[Dict] = None, + share_identifier: Optional[str] = None, + tags: Optional[Dict] = None, +) -> Dict: + """Batch submit_service_job API helper function. + + Args: + training_payload: a dict containing a dict of arguments for Training job. + job_name: Batch job name. + job_queue: Batch job queue ARN. + retry_config: Batch job retry configuration. + scheduling_priority: An integer representing scheduling priority. + timeout: Set with value of timeout if specified, else default to 1 day. + share_identifier: value of shareIdentifier if specified. + tags: A dict of string to string representing Batch tags. + + Returns: + A dict containing jobArn, jobName and jobId. + """ + if timeout is None: + timeout = DEFAULT_TIMEOUT + client = get_batch_boto_client() + training_payload_tags = training_payload.pop("Tags", None) + payload = { + "jobName": job_name, + "jobQueue": job_queue, + "retryStrategy": DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + "serviceJobType": SAGEMAKER_TRAINING, + "serviceRequestPayload": json.dumps(training_payload), + "timeoutConfig": timeout, + } + if retry_config: + payload["retryStrategy"] = retry_config + if scheduling_priority: + payload["schedulingPriority"] = scheduling_priority + if share_identifier: + payload["shareIdentifier"] = share_identifier + if tags or training_payload_tags: + payload["tags"] = __merge_tags(tags, training_payload_tags) + return client.submit_service_job(**payload) + + +def describe_service_job(job_id: str) -> Dict: + """Batch describe_service_job API helper function. + + Args: + job_id: Job ID used. + + Returns: a dict. See the sample below + { + 'attempts': [ + { + 'serviceResourceId': { + 'name': 'string', + 'value': 'string' + }, + 'startedAt': 123, + 'stoppedAt': 123, + 'statusReason': 'string' + }, + ], + 'createdAt': 123, + 'isTerminated': True|False, + 'jobArn': 'string', + 'jobId': 'string', + 'jobName': 'string', + 'jobQueue': 'string', + 'retryStrategy': { + 'attempts': 123 + }, + 'schedulingPriority': 123, + 'serviceRequestPayload': 'string', + 'serviceJobType': 'EKS'|'ECS'|'ECS_FARGATE'|'SAGEMAKER_TRAINING', + 'shareIdentifier': 'string', + 'startedAt': 123, + 'status': 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED', + 'statusReason': 'string', + 'stoppedAt': 123, + 'tags': { + 'string': 'string' + }, + 'timeout': { + 'attemptDurationSeconds': 123 + } + } + """ + client = get_batch_boto_client() + return client.describe_service_job(jobId=job_id) + + +def terminate_service_job(job_id: str, reason: Optional[str] = "default terminate reason") -> Dict: + """Batch terminate_service_job API helper function. + + Args: + job_id: Job ID + reason: A string representing terminate reason. 
+
+    Returns: an empty dict
+    """
+    client = get_batch_boto_client()
+    return client.terminate_service_job(jobId=job_id, reason=reason)
+
+
+def list_service_job(
+    job_queue: str,
+    job_status: Optional[str] = None,
+    filters: Optional[List] = None,
+    next_token: Optional[str] = None,
+) -> Dict:
+    """Batch list_service_job API helper function.
+
+    Args:
+        job_queue: Batch job queue ARN.
+        job_status: Batch job status.
+        filters: A list of Dict. Each contains a filter.
+        next_token: Used to retrieve data in the next page.
+
+    Returns: A generator yielding pages of list results.
+
+    """
+    client = get_batch_boto_client()
+    payload = {"jobQueue": job_queue}
+    if filters:
+        payload["filters"] = filters
+    if next_token:
+        payload["nextToken"] = next_token
+    if job_status:
+        payload["jobStatus"] = job_status
+    part_of_jobs = client.list_service_jobs(**payload)
+    next_token = part_of_jobs.get("nextToken")
+    yield part_of_jobs
+    if next_token:
+        yield from list_service_job(job_queue, job_status, filters, next_token)
+
+
+def __merge_tags(batch_tags: Optional[Dict], training_tags: Optional[List]) -> Optional[Dict]:
+    """Merges Batch and training payload tags.
+
+    Returns a copy of Batch tags merged with training payload tags. Training payload tags take
+    precedence in the case of key conflicts.
+
+    :param batch_tags: A dict of string to string representing Batch tags.
+    :param training_tags: A list of `{"Key": "string", "Value": "string"}` objects representing
+        training payload tags.
+    :return: A dict of string to string representing batch tags merged with training tags.
+        batch_tags is returned unaltered if training_tags is None or empty.
+    """
+    if not training_tags:
+        return batch_tags
+
+    training_tags_to_merge = {tag["Key"]: tag["Value"] for tag in training_tags}
+    batch_tags_copy = batch_tags.copy() if batch_tags else {}
+    batch_tags_copy.update(training_tags_to_merge)
+
+    return batch_tags_copy
diff --git a/src/sagemaker/aws_batch/boto_client.py b/src/sagemaker/aws_batch/boto_client.py
new file mode 100644
index 0000000000..87f3486887
--- /dev/null
+++ b/src/sagemaker/aws_batch/boto_client.py
@@ -0,0 +1,33 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""The file provides a helper function for getting the Batch boto client."""
+from __future__ import absolute_import
+
+from typing import Optional
+import boto3
+
+
+def get_batch_boto_client(
+    region: Optional[str] = None,
+    endpoint: Optional[str] = None,
+) -> boto3.session.Session.client:
+    """Helper function for getting a Batch boto3 client.
+
+    Args:
+        region: Region specified.
+        endpoint: Batch API endpoint.
+
+    Returns: Batch boto3 client.
+
+    """
+    return boto3.client("batch", region_name=region, endpoint_url=endpoint)
diff --git a/src/sagemaker/aws_batch/constants.py b/src/sagemaker/aws_batch/constants.py
new file mode 100644
index 0000000000..ee41d3a413
--- /dev/null
+++ b/src/sagemaker/aws_batch/constants.py
@@ -0,0 +1,34 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""The file defines constants used for Batch API helper functions."""
+
+from __future__ import absolute_import
+
+SAGEMAKER_TRAINING = "SAGEMAKER_TRAINING"
+DEFAULT_ATTEMPT_DURATION_IN_SECONDS = 86400  # 1 day in seconds.
+DEFAULT_TIMEOUT = {"attemptDurationSeconds": DEFAULT_ATTEMPT_DURATION_IN_SECONDS}
+POLL_IN_SECONDS = 5
+JOB_STATUS_RUNNING = "RUNNING"
+JOB_STATUS_COMPLETED = "SUCCEEDED"
+JOB_STATUS_FAILED = "FAILED"
+DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG = {
+    "attempts": 1,
+    "evaluateOnExit": [
+        {
+            "action": "RETRY",
+            "onStatusReason": "Received status from SageMaker:InternalServerError: "
+            "We encountered an internal error. Please try again.",
+        },
+        {"action": "EXIT", "onStatusReason": "*"},
+    ],
+}
diff --git a/src/sagemaker/aws_batch/exception.py b/src/sagemaker/aws_batch/exception.py
new file mode 100644
index 0000000000..94318bbce4
--- /dev/null
+++ b/src/sagemaker/aws_batch/exception.py
@@ -0,0 +1,52 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""The file defines customized exceptions for Batch queueing."""
+from __future__ import absolute_import
+
+
+class NoTrainingJob(Exception):
+    """Define NoTrainingJob Exception.
+
+    It means no Training job has been created by the AWS Batch service.
+    """
+
+    def __init__(self, value):
+        super().__init__(value)
+        self.value = value
+
+    def __str__(self):
+        """Convert Exception to string.
+
+        Returns: a string containing the exception error messages.
+
+        """
+        return repr(self.value)
+
+
+class MissingRequiredArgument(Exception):
+    """Define MissingRequiredArgument exception.
+
+    It means some required arguments are missing.
+    """
+
+    def __init__(self, value):
+        super().__init__(value)
+        self.value = value
+
+    def __str__(self):
+        """Convert Exception to string.
+
+        Returns: a string containing the exception error messages.
+
+        """
+        return repr(self.value)
diff --git a/src/sagemaker/aws_batch/training_queue.py b/src/sagemaker/aws_batch/training_queue.py
new file mode 100644
index 0000000000..b540fad0a9
--- /dev/null
+++ b/src/sagemaker/aws_batch/training_queue.py
@@ -0,0 +1,212 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Define Queue class for AWS Batch service"""
+from __future__ import absolute_import
+
+from typing import Dict, Optional, List, Union
+import logging
+from sagemaker.estimator import EstimatorBase, _TrainingJob
+from sagemaker.modules.train.model_trainer import ModelTrainer, Mode
+from .training_queued_job import TrainingQueuedJob
+from .batch_api_helper import submit_service_job, list_service_job
+from .exception import MissingRequiredArgument
+from .constants import DEFAULT_TIMEOUT, JOB_STATUS_RUNNING
+
+
+class TrainingQueue:
+    """TrainingQueue class for AWS Batch service
+
+    With this class, customers are able to create a new queue and submit jobs to AWS Batch Service.
+    """
+
+    def __init__(self, queue_name: str):
+        self.queue_name = queue_name
+
+    def submit(
+        self,
+        training_job: Union[EstimatorBase, ModelTrainer],
+        inputs,
+        job_name: Optional[str] = None,
+        retry_config: Optional[Dict] = None,
+        priority: Optional[int] = None,
+        share_identifier: Optional[str] = None,
+        timeout: Optional[Dict] = None,
+        tags: Optional[Dict] = None,
+        experiment_config: Optional[Dict] = None,
+    ) -> TrainingQueuedJob:
+        """Submit a queued job and return a TrainingQueuedJob object.
+
+        Args:
+            training_job: Training job EstimatorBase or ModelTrainer object.
+            inputs: Training job inputs.
+            job_name: Batch job name.
+            retry_config: Retry configuration for the Batch job.
+            priority: Scheduling priority for the Batch job.
+            share_identifier: Share identifier for the Batch job.
+            timeout: Timeout configuration for the Batch job.
+            tags: Tags to apply to the Batch job. These tags are for the Batch job only.
+            experiment_config: Experiment management configuration.
+                Optionally, the dict can contain four keys:
+                'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
+
+        Returns: a TrainingQueuedJob object with the Batch job ARN and job name.
+
+        """
+        if not isinstance(training_job, (EstimatorBase, ModelTrainer)):
+            raise TypeError(
+                "training_job must be an instance of EstimatorBase or ModelTrainer, "
+                f"but got {type(training_job)}"
+            )
+
+        training_payload = {}
+        if isinstance(training_job, EstimatorBase):
+            if experiment_config is None:
+                experiment_config = {}
+            training_job.prepare_workflow_for_training(job_name)
+            training_args = _TrainingJob.get_train_args(training_job, inputs, experiment_config)
+            training_payload = training_job.sagemaker_session.get_train_request(**training_args)
+        else:
+            if training_job.training_mode != Mode.SAGEMAKER_TRAINING_JOB:
+                raise ValueError(
+                    "TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB"
+                )
+            if experiment_config is not None:
+                logging.warning(
+                    "ExperimentConfig is not supported for ModelTrainer. "
+                    "It will be ignored when submitting the job."
+                )
+            training_payload = training_job._create_training_job_args(
+                input_data_config=inputs, boto3=True
+            )
+
+        if timeout is None:
+            timeout = DEFAULT_TIMEOUT
+        if job_name is None:
+            job_name = training_payload["TrainingJobName"]
+
+        resp = submit_service_job(
+            training_payload,
+            job_name,
+            self.queue_name,
+            retry_config,
+            priority,
+            timeout,
+            share_identifier,
+            tags,
+        )
+        if "jobArn" not in resp or "jobName" not in resp:
+            raise MissingRequiredArgument(
+                "jobArn or jobName is missing in response from Batch submit_service_job API"
+            )
+        return TrainingQueuedJob(resp["jobArn"], resp["jobName"])
+
+    def map(
+        self,
+        training_job: Union[EstimatorBase, ModelTrainer],
+        inputs,
+        job_names: Optional[List[str]] = None,
+        retry_config: Optional[Dict] = None,
+        priority: Optional[int] = None,
+        share_identifier: Optional[str] = None,
+        timeout: Optional[Dict] = None,
+        tags: Optional[Dict] = None,
+        experiment_config: Optional[Dict] = None,
+    ) -> List[TrainingQueuedJob]:
+        """Submit one queued job per input and return a list of TrainingQueuedJob objects.
+
+        Args:
+            training_job: Training job EstimatorBase or ModelTrainer object.
+            inputs: List of Training job inputs.
+            job_names: List of Batch job names.
+            retry_config: Retry config for the Batch jobs.
+            priority: Scheduling priority for the Batch jobs.
+            share_identifier: Share identifier for the Batch jobs.
+            timeout: Timeout configuration for the Batch jobs.
+            tags: Tags to apply to the Batch jobs. These tags are for the Batch jobs only.
+            experiment_config: Experiment management configuration.
+                Optionally, the dict can contain four keys:
+                'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
+
+        Returns: a list of TrainingQueuedJob objects, each with its Batch job ARN and job name.
+
+        """
+        if experiment_config is None:
+            experiment_config = {}
+
+        if job_names is not None:
+            if len(job_names) != len(inputs):
+                raise ValueError(
+                    "When specified, the number of job names must match the number of inputs"
+                )
+        else:
+            job_names = [None] * len(inputs)
+
+        queued_batch_job_list = []
+        for index, value in enumerate(inputs):
+            queued_batch_job = self.submit(
+                training_job,
+                value,
+                job_names[index],
+                retry_config,
+                priority,
+                share_identifier,
+                timeout,
+                tags,
+                experiment_config,
+            )
+            queued_batch_job_list.append(queued_batch_job)
+
+        return queued_batch_job_list
+
+    def list_jobs(
+        self, job_name: Optional[str] = None, status: Optional[str] = JOB_STATUS_RUNNING
+    ) -> List[TrainingQueuedJob]:
+        """List Batch jobs according to job_name or status.
+
+        Args:
+            job_name: Batch job name.
+            status: Batch job status.
+
+        Returns: A list of TrainingQueuedJob objects.
+
+        """
+        filters = None
+        if job_name:
+            filters = [{"name": "JOB_NAME", "values": [job_name]}]
+            status = None  # job_status is ignored when job_name is specified.
+        jobs_to_return = []
+        next_token = None
+        for job_result_dict in list_service_job(self.queue_name, status, filters, next_token):
+            for job_result in job_result_dict.get("jobSummaryList", []):
+                if "jobArn" in job_result and "jobName" in job_result:
+                    jobs_to_return.append(
+                        TrainingQueuedJob(job_result["jobArn"], job_result["jobName"])
+                    )
+                else:
+                    logging.warning("Missing JobArn or JobName in Batch ListJobs API response")
+                    continue
+        return jobs_to_return
+
+    def get_job(self, job_name):
+        """Get a Batch job by job_name.
+
+        Args:
+            job_name: Batch job name.
+
+        Returns: The TrainingQueuedJob with name matching job_name.
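+
+        Example (illustrative; assumes a job with this name was submitted to the queue)::
+
+            queue = TrainingQueue("my-training-queue")
+            queued_job = queue.get_job("my-batch-job-name")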
+ + """ + jobs_to_return = self.list_jobs(job_name) + if len(jobs_to_return) == 0: + raise ValueError(f"Cannot find job: {job_name}") + return jobs_to_return[0] diff --git a/src/sagemaker/aws_batch/training_queued_job.py b/src/sagemaker/aws_batch/training_queued_job.py new file mode 100644 index 0000000000..6bb42c3c61 --- /dev/null +++ b/src/sagemaker/aws_batch/training_queued_job.py @@ -0,0 +1,217 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Define QueuedJob class for AWS Batch service""" +from __future__ import absolute_import + +import logging +import time +import asyncio +from typing import Optional, Dict +import nest_asyncio +from sagemaker.estimator import Estimator +from .batch_api_helper import terminate_service_job, describe_service_job +from .exception import NoTrainingJob, MissingRequiredArgument +from ..utils import get_training_job_name_from_training_job_arn +from .constants import JOB_STATUS_COMPLETED, JOB_STATUS_FAILED, POLL_IN_SECONDS + +logging.basicConfig( + format="%(asctime)s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" +) + + +class TrainingQueuedJob: + """TrainingQueuedJob class for AWS Batch service. + + With this class, customers are able to attach the latest training job to an estimator. + """ + + def __init__(self, job_arn: str, job_name: str): + self.job_arn = job_arn + self.job_name = job_name + self._no_training_job_status = {"SUBMITTED", "PENDING", "RUNNABLE"} + + def get_estimator(self) -> Estimator: + """Attach the latest training job to an estimator and return. + + Returns: an Estimator instance. + + """ + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + if self._training_job_created(job_status): + if "latestAttempt" not in describe_resp: + raise MissingRequiredArgument("No LatestAttempt in describe call") + new_training_job_name = _get_new_training_job_name_from_latest_attempt( + describe_resp["latestAttempt"] + ) + output_estimator = _construct_estimator_from_training_job_name(new_training_job_name) + _remove_system_tags_in_place_in_estimator_object(output_estimator) + return output_estimator + + _output_attempt_history(describe_resp) + raise NoTrainingJob("No Training job created. Job is still waiting in queue") + + def terminate(self, reason: Optional[str] = "Default terminate reason") -> None: + """Terminate Batch job. + + Args: + reason: Reason for terminating a job. + + Returns: None + + """ + terminate_service_job(self.job_arn, reason) + + def describe(self) -> Dict: + """Describe Batch job. + + Returns: A dict which includes job parameters, job status, attempts and so on. + + """ + return describe_service_job(self.job_arn) + + def _training_job_created(self, status: str) -> bool: + """Return True if a Training job has been created + + Args: + status: Job status returned from Batch API. + + Returns: a boolean indicating whether a Training job has been created. 
+ + """ + return status not in self._no_training_job_status + + def result(self, timeout: int = None) -> Dict: + """Fetch the terminal result of the Batch job. + + Args: + timeout: The time to wait for the Batch job to complete. Defaults to ``None``. + + Returns: The results of the Batch job, represented as a Dict. + + """ + nest_asyncio.apply() + loop = asyncio.get_event_loop() + task = loop.create_task(self.fetch_job_results(timeout)) + resp = loop.run_until_complete(task) + return resp + + async def fetch_job_results(self, timeout: int = None) -> Dict: + """Async method that waits for the Batch job to complete or until timeout. + + Args: + timeout: The time to wait for the Batch job to complete. Defaults to ``None``. + + Returns: The results of the Batch job, represented as a Dict, or an Error. + + """ + self.wait(timeout) + + describe_resp = self.describe() + if describe_resp.get("status", "") == JOB_STATUS_COMPLETED: + return describe_resp + if describe_resp.get("status", "") == JOB_STATUS_FAILED: + raise RuntimeError(describe_resp["statusReason"]) + raise TimeoutError("Reached timeout before the Batch job reached a terminal status") + + def wait(self, timeout: int = None) -> Dict: + """Wait for the Batch job to finish. + + This method blocks on the job completing for up to the timeout value (if specified). + If timeout is ``None``, this method will block until the job is completed. + + Args: + timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by + default. + + Returns: The last describe_service_job response for the Batch job. + """ + request_end_time = time.time() + timeout if timeout else None + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + job_completed = job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED) + + while not job_completed: + if timeout and time.time() > request_end_time: + logging.info( + "Timeout exceeded: %d seconds elapsed. Returning current results", timeout + ) + break + if job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED): + break + + time.sleep(POLL_IN_SECONDS) + describe_resp = self.describe() + job_status = describe_resp.get("status", "") + job_completed = job_status in (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED) + + return describe_resp + + +def _construct_estimator_from_training_job_name(training_job_name: str) -> Estimator: + """Build Estimator instance from payload. + + Args: + training_job_name: Training job name. + + Returns: an Estimator instance. + + """ + return Estimator.attach(training_job_name) + + +def _output_attempt_history(describe_resp: Dict) -> None: + """Print attempt history if no Training job created. + + Args: + describe_resp: Describe response from Batch API. + + Returns: None + + """ + has_seen_status_reason = False + for i, attempt_dict in enumerate(describe_resp.get("attempts", [])): + if "statusReason" in attempt_dict: + logging.info("Attempt %d - %s", i + 1, attempt_dict["statusReason"]) + has_seen_status_reason = True + if not has_seen_status_reason: + logging.info("No attempts found or no statusReason found.") + + +def _get_new_training_job_name_from_latest_attempt(latest_attempt: Dict) -> str: + """Extract new Training job name from latest attempt in Batch Describe response. + + Args: + latest_attempt: a Dict containing Training job arn. + + Returns: new Training job name or None if not found. 
+ + """ + training_job_arn = latest_attempt.get("serviceResourceId", {}).get("value", None) + return get_training_job_name_from_training_job_arn(training_job_arn) + + +def _remove_system_tags_in_place_in_estimator_object(estimator: Estimator) -> None: + """Remove system tags in place. + + Args: + estimator: input Estimator object. + + Returns: None. Remove system tags in place. + + """ + new_tags = [] + for tag_dict in estimator.tags: + if not tag_dict.get("Key", "").startswith("aws:"): + new_tags.append(tag_dict) + estimator.tags = new_tags diff --git a/src/sagemaker/base_deserializers.py b/src/sagemaker/base_deserializers.py index a152f0144d..ded68fc4b0 100644 --- a/src/sagemaker/base_deserializers.py +++ b/src/sagemaker/base_deserializers.py @@ -23,6 +23,7 @@ import numpy as np from six import with_metaclass +from sagemaker.serializer_utils import read_records from sagemaker.utils import DeferredError try: @@ -388,3 +389,31 @@ def deserialize(self, stream, content_type="tensor/pt"): "Unable to deserialize your data to torch.Tensor.\ Please provide custom deserializer in InferenceSpec." ) + + +class RecordDeserializer(SimpleBaseDeserializer): + """Deserialize RecordIO Protobuf data from an inference endpoint.""" + + def __init__(self, accept="application/x-recordio-protobuf"): + """Initialize a ``RecordDeserializer`` instance. + + Args: + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: + "application/x-recordio-protobuf"). + """ + super(RecordDeserializer, self).__init__(accept=accept) + + def deserialize(self, data, content_type): + """Deserialize RecordIO Protobuf data from an inference endpoint. + + Args: + data (object): The protobuf message to deserialize. + content_type (str): The MIME type of the data. + Returns: + list: A list of records. + """ + try: + return read_records(data) + finally: + data.close() diff --git a/src/sagemaker/base_predictor.py b/src/sagemaker/base_predictor.py index 1a7eea9cd7..a9b2cb021d 100644 --- a/src/sagemaker/base_predictor.py +++ b/src/sagemaker/base_predictor.py @@ -430,6 +430,8 @@ def update_endpoint( - If ``initial_instance_count``, ``instance_type``, or ``accelerator_type`` is specified and either ``model_name`` is ``None`` or there are multiple models associated with the endpoint. + botocore.exceptions.ClientError: If SageMaker throws an error while creating + endpoint config, describing endpoint or updating endpoint """ production_variants = None current_model_names = self._get_model_names() diff --git a/src/sagemaker/base_serializers.py b/src/sagemaker/base_serializers.py index 45fea23493..0e1df120ff 100644 --- a/src/sagemaker/base_serializers.py +++ b/src/sagemaker/base_serializers.py @@ -22,6 +22,7 @@ from pandas import DataFrame from six import with_metaclass +from sagemaker.serializer_utils import write_numpy_to_dense_tensor from sagemaker.utils import DeferredError try: @@ -466,3 +467,39 @@ def serialize(self, data): ) raise ValueError("Object of type %s is not a torch.Tensor" % type(data)) + + +class RecordSerializer(SimpleBaseSerializer): + """Serialize a NumPy array for an inference request.""" + + def __init__(self, content_type="application/x-recordio-protobuf"): + """Initialize a ``RecordSerializer`` instance. + + Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "application/x-recordio-protobuf"). 
+ """ + super(RecordSerializer, self).__init__(content_type=content_type) + + def serialize(self, data): + """Serialize a NumPy array into a buffer containing RecordIO records. + + Args: + data (numpy.ndarray): The data to serialize. + + Returns: + io.BytesIO: A buffer containing the data serialized as records. + """ + if len(data.shape) == 1: + data = data.reshape(1, data.shape[0]) + + if len(data.shape) != 2: + raise ValueError( + "Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape) + ) + + buffer = io.BytesIO() + write_numpy_to_dense_tensor(buffer, data) + buffer.seek(0) + + return buffer diff --git a/src/sagemaker/batch_inference/__init__.py b/src/sagemaker/batch_inference/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/batch_inference/batch_transform_inference_config.py b/src/sagemaker/batch_inference/batch_transform_inference_config.py new file mode 100644 index 0000000000..3d3618d7fb --- /dev/null +++ b/src/sagemaker/batch_inference/batch_transform_inference_config.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Config Classes for taking in parameters for Batch Inference""" + +from __future__ import absolute_import +from pydantic import BaseModel + + +class BatchTransformInferenceConfig(BaseModel): + """Config class for Batch Transform Inference + + * Can be used to deploy from ModelBuilder + """ + + instance_count: int + instance_type: str + output_path: str diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py index 09addf9910..f493c10846 100644 --- a/src/sagemaker/chainer/estimator.py +++ b/src/sagemaker/chainer/estimator.py @@ -54,7 +54,7 @@ def __init__( framework_version: Optional[str] = None, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - **kwargs + **kwargs, ): """This ``Estimator`` executes an Chainer script in a managed execution environment. @@ -173,7 +173,7 @@ def create_model( entry_point=None, source_dir=None, dependencies=None, - **kwargs + **kwargs, ): """Create a SageMaker ``ChainerModel`` object that can be deployed to an ``Endpoint``. 
@@ -225,7 +225,7 @@ def create_model(
             sagemaker_session=self.sagemaker_session,
             vpc_config=self.get_vpc_config(vpc_config_override),
             dependencies=(dependencies or self.dependencies),
-            **kwargs
+            **kwargs,
         )
 
     @classmethod
diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py
index 9fce051454..c2d2187b69 100644
--- a/src/sagemaker/chainer/model.py
+++ b/src/sagemaker/chainer/model.py
@@ -14,7 +14,7 @@
 from __future__ import absolute_import
 
 import logging
-from typing import Optional, Union, List, Dict
+from typing import Callable, Optional, Union, List, Dict
 
 import sagemaker
 from sagemaker import image_uris, ModelMetrics
@@ -28,11 +28,16 @@
 from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
 from sagemaker.chainer import defaults
 from sagemaker.deserializers import NumpyDeserializer
+from sagemaker.model_card import (
+    ModelCard,
+    ModelPackageModelCard,
+)
 from sagemaker.predictor import Predictor
 from sagemaker.serializers import NumpySerializer
 from sagemaker.utils import to_string
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.model_life_cycle import ModelLifeCycle
 
 logger = logging.getLogger("sagemaker")
 
@@ -91,9 +96,9 @@ def __init__(
         image_uri: Optional[Union[str, PipelineVariable]] = None,
         framework_version: Optional[str] = None,
         py_version: Optional[str] = None,
-        predictor_cls: callable = ChainerPredictor,
+        predictor_cls: Optional[Callable] = ChainerPredictor,
         model_server_workers: Optional[Union[int, PipelineVariable]] = None,
-        **kwargs
+        **kwargs,
     ):
         """Initialize an ChainerModel.
@@ -120,7 +125,7 @@ def __init__(
             py_version (str): Python version you want to use for executing your
                 model training code. Defaults to ``None``. Required unless
                 ``image_uri`` is provided.
-            predictor_cls (callable[str, sagemaker.session.Session]): A function
+            predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function
                 to call to create a predictor with an endpoint name and
                 SageMaker ``Session``. If specified, ``deploy()`` returns
                 the result of invoking this function on the created endpoint
                 name.
@@ -175,6 +180,8 @@ def register(
         data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
         skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
         source_uri: Optional[Union[str, PipelineVariable]] = None,
+        model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
+        model_life_cycle: Optional[ModelLifeCycle] = None,
     ):
         """Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -226,6 +233,9 @@ def register(
                 validation. Values can be "All" or "None" (default: None).
             source_uri (str or PipelineVariable): The URI of the source for the model package
                 (default: None).
+            model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative
+                and quantitative information about a model (default: None).
+            model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
 
         Returns:
             str: A string of SageMaker Model Package ARN.
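Note: with ``predictor_cls`` now annotated as ``Optional[Callable]``, any callable accepting an endpoint name and a SageMaker ``Session`` satisfies the documented contract; a hypothetical factory:

    from sagemaker.predictor import Predictor

    def my_predictor_factory(endpoint_name, sagemaker_session):
        # deploy() invokes this with the created endpoint name and the active Session.
        return Predictor(endpoint_name, sagemaker_session=sagemaker_session)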
@@ -266,6 +276,8 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + model_life_cycle=model_life_cycle, ) def prepare_container_def( @@ -274,6 +286,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition with framework configuration set in model environment. @@ -325,6 +338,7 @@ def prepare_container_def( self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 246cdbcc2d..d9b9a3021c 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -870,7 +870,7 @@ class BiasConfig: def __init__( self, - label_values_or_threshold: Union[int, float, str], + label_values_or_threshold: List[Union[int, float, str]], facet_name: Union[str, int, List[str], List[int]], facet_values_or_threshold: Optional[Union[int, float, str]] = None, group_name: Optional[str] = None, diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py index 0e2aabbec4..54bccba55e 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -51,8 +51,8 @@ "StreamDeserializer": ("sagemaker.deserializers",), "NumpyDeserializer": ("sagemaker.deserializers",), "JSONDeserializer": ("sagemaker.deserializers",), - "RecordSerializer ": ("sagemaker.amazon.common",), - "RecordDeserializer": ("sagemaker.amazon.common",), + "RecordSerializer ": ("sagemaker.serializers",), + "RecordDeserializer": ("sagemaker.deserializers",), } OLD_CLASS_NAME_TO_NEW_CLASS_NAME = { @@ -101,8 +101,8 @@ def node_should_be_modified(self, node): - ``sagemaker.predictor.StreamDeserializer`` - ``sagemaker.predictor._NumpyDeserializer`` - ``sagemaker.predictor._JsonDeserializer`` - - ``sagemaker.amazon.common.numpy_to_record_serializer`` - - ``sagemaker.amazon.common.record_deserializer`` + - ``sagemaker.serializers.numpy_to_record_serializer`` + - ``sagemaker.deserializers.record_deserializer`` Args: node (ast.Call): a node that represents a function call. For more, @@ -128,8 +128,8 @@ def modify_node(self, node): - ``sagemaker.deserializers.StreamDeserializer`` - ``sagemaker.deserializers.NumpyDeserializer`` - ``sagemaker.deserializers._JsonDeserializer`` - - ``sagemaker.amazon.common.RecordSerializer`` - - ``sagemaker.amazon.common.RecordDeserializer`` + - ``sagemaker.serializers.RecordSerializer`` + - ``sagemaker.deserializers.RecordDeserializer`` Args: node (ast.Call): a node that represents a SerDe constructor. @@ -303,8 +303,8 @@ def node_should_be_modified(self, node): """Checks if the import statement imports a SerDe from the ``sagemaker.amazon.common``. This checks for: - - ``sagemaker.amazon.common.numpy_to_record_serializer`` - - ``sagemaker.amazon.common.record_deserializer`` + - ``sagemaker.serializers.numpy_to_record_serializer`` + - ``sagemaker.deserializers.record_deserializer`` Args: node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. @@ -322,8 +322,8 @@ def modify_node(self, node): """Upgrades the ``numpy_to_record_serializer`` and ``record_deserializer`` imports. 
This upgrades the classes to (if applicable): - - ``sagemaker.amazon.common.RecordSerializer`` - - ``sagemaker.amazon.common.RecordDeserializer`` + - ``sagemaker.serializers.RecordSerializer`` + - ``sagemaker.deserializers.RecordDeserializer`` Args: node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. diff --git a/src/sagemaker/config/config_schema.py b/src/sagemaker/config/config_schema.py index 35c4859930..61da17e7cf 100644 --- a/src/sagemaker/config/config_schema.py +++ b/src/sagemaker/config/config_schema.py @@ -116,6 +116,7 @@ REGION_NAME = "region_name" TELEMETRY_OPT_OUT = "TelemetryOptOut" NOTEBOOK_JOB = "NotebookJob" +MODEL_TRAINER = "ModelTrainer" def _simple_path(*args: str): @@ -142,6 +143,7 @@ def _simple_path(*args: str): ) TRAINING_JOB_ROLE_ARN_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, ROLE_ARN) TRAINING_JOB_VPC_CONFIG_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, VPC_CONFIG) +TRAINING_JOB_TAGS_PATH = _simple_path(SAGEMAKER, TRAINING_JOB, TAGS) TRAINING_JOB_SECURITY_GROUP_IDS_PATH = _simple_path( TRAINING_JOB_VPC_CONFIG_PATH, SECURITY_GROUP_IDS ) @@ -538,7 +540,8 @@ def _simple_path(*args: str): "minItems": 0, "maxItems": 50, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment "environmentVariables": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -551,13 +554,15 @@ def _simple_path(*args: str): }, "maxProperties": 48, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri "s3Uri": { TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint "preExecutionCommand": {TYPE: "string", "pattern": r".*"}, # Regex based on https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_PipelineDefinitionS3Location.html # except with an additional ^ and $ for the beginning and the end to closer align to @@ -568,7 +573,8 @@ def _simple_path(*args: str): "minLength": 3, "maxLength": 63, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_MonitoringJobDefinition.html#sagemaker-Type-MonitoringJobDefinition-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_MonitoringJobDefinition.html#sagemaker-Type-MonitoringJobDefinition-Environment "environment-Length256-Properties50": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -581,7 +587,8 @@ def _simple_path(*args: str): }, "maxProperties": 50, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTransformJob.html#sagemaker-CreateTransformJob-request-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_CreateTransformJob.html#sagemaker-CreateTransformJob-request-Environment 
"environment-Length10240-Properties16": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -594,7 +601,8 @@ def _simple_path(*args: str): }, "maxProperties": 16, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ContainerDefinition.html#sagemaker-Type-ContainerDefinition-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_ContainerDefinition.html#sagemaker-Type-ContainerDefinition-Environment "environment-Length1024-Properties16": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -607,7 +615,8 @@ def _simple_path(*args: str): }, "maxProperties": 16, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html#sagemaker-CreateProcessingJob-request-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_CreateProcessingJob.html#sagemaker-CreateProcessingJob-request-Environment "environment-Length256-Properties100": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -620,7 +629,8 @@ def _simple_path(*args: str): }, "maxProperties": 100, }, - # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment + # Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/ + # API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment "environment-Length512-Properties48": { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, @@ -656,6 +666,25 @@ def _simple_path(*args: str): "minItems": 1, "maxItems": 15, }, + "role": { + TYPE: "string", + "pattern": r"^arn:aws[a-z\-]*:iam::\d{12}:role/?[a-zA-Z_0-9+=,.@\-_/]+$", + "minLength": 20, + "maxLength": 2048, + }, + "baseJobName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "sourceCode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "distributed": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "compute": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "networking": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "stoppingCondition": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "trainingImage": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "trainingImageConfig": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "algorithmName": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "outputDataConfig": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "trainingInputMode": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "environment": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, + "hyperparameters": {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, }, PROPERTIES: { SCHEMA_VERSION: { @@ -709,6 +738,7 @@ def _simple_path(*args: str): }, }, }, + MODEL_TRAINER: {TYPE: OBJECT, ADDITIONAL_PROPERTIES: True}, ESTIMATOR: { TYPE: OBJECT, ADDITIONAL_PROPERTIES: False, diff --git a/src/sagemaker/config/config_utils.py b/src/sagemaker/config/config_utils.py index 70f2764529..e81a620c9a 100644 --- a/src/sagemaker/config/config_utils.py +++ b/src/sagemaker/config/config_utils.py @@ -20,6 +20,8 @@ import logging import sys from typing import Callable +import re +from copy import deepcopy def get_sagemaker_config_logger(): @@ -67,6 +69,19 @@ def _log_sagemaker_config_single_substitution(source_value, config_value, config """ logger = get_sagemaker_config_logger() + source_value_log_copy = deepcopy(source_value) + config_value_log_copy = deepcopy(config_value) + + if isinstance(source_value_log_copy, dict): + for key in source_value_log_copy.keys(): + if 
re.search(r"(secret|password|key|token)", key, re.IGNORECASE): + source_value_log_copy[key] = "***" + + if isinstance(config_value_log_copy, dict): + for key in config_value_log_copy.keys(): + if re.search(r"(secret|password|key|token)", key, re.IGNORECASE): + config_value_log_copy[key] = "***" + if config_value is not None: if source_value is None: @@ -79,7 +94,7 @@ def _log_sagemaker_config_single_substitution(source_value, config_value, config logger.debug( "Applied value\n config key = %s\n config value that will be used = %s", config_key_path, - config_value, + config_value_log_copy, ) else: logger.info( @@ -102,8 +117,8 @@ def _log_sagemaker_config_single_substitution(source_value, config_value, config " source value that will be used = %s" ), config_key_path, - config_value, - source_value, + config_value_log_copy, + source_value_log_copy, ) elif source_value is not None and config_value != source_value: # Sagemaker Config had a value defined that is NOT going to be used @@ -117,8 +132,8 @@ def _log_sagemaker_config_single_substitution(source_value, config_value, config " source value that will be used = %s", ), config_key_path, - config_value, - source_value, + config_value_log_copy, + source_value_log_copy, ) else: # nothing was specified in the config and nothing is being automatically applied diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 46d0361f67..16c81d6d77 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -24,6 +24,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -37,6 +38,8 @@ def retrieve_options( retrieve the supported content types. (Default: None). model_version (str): The version of the model for which to retrieve the supported content types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -60,11 +63,12 @@ def retrieve_options( ) return artifacts._retrieve_supported_content_types( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -73,10 +77,12 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -87,6 +93,8 @@ def retrieve_default( retrieve the default content type. (Default: None). model_version (str): The version of the model for which to retrieve the default content type. (Default: None). 
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -98,6 +106,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default content type to use for the model. @@ -110,13 +119,15 @@ def retrieve_default( ) return artifacts._retrieve_default_content_type( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 1a4be43897..dad5137329 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -31,8 +31,10 @@ StreamDeserializer, StringDeserializer, TorchTensorDeserializer, + RecordDeserializer, ) +from sagemaker.deprecations import deprecated_class from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartModelType @@ -43,6 +45,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -56,6 +59,8 @@ def retrieve_options( retrieve the supported deserializers. (Default: None). model_version (str): The version of the model for which to retrieve the supported deserializers. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -80,11 +85,12 @@ def retrieve_options( ) return artifacts._retrieve_deserializer_options( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -93,10 +99,12 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. @@ -107,6 +115,8 @@ def retrieve_default( retrieve the default deserializer. (Default: None). model_version (str): The version of the model for which to retrieve the default deserializer. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -118,6 +128,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseDeserializer: The default deserializer to use for the model. 
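Note: a sketch of the widened ``sagemaker.deserializers.retrieve_default`` call (the model id below is hypothetical; ``hub_arn`` and ``config_name`` are the new optional arguments):

    from sagemaker import deserializers

    deserializer = deserializers.retrieve_default(
        region="us-west-2",
        model_id="huggingface-text2text-flan-t5-base",  # hypothetical JumpStart model id
        model_version="*",
        config_name=None,  # optionally pin a specific JumpStart model config
    )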
@@ -131,11 +142,16 @@ def retrieve_default( ) return artifacts._retrieve_default_deserializer( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) + + +record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer") diff --git a/src/sagemaker/djl_inference/__init__.py b/src/sagemaker/djl_inference/__init__.py index 0f6b867318..dd8b005d1e 100644 --- a/src/sagemaker/djl_inference/__init__.py +++ b/src/sagemaker/djl_inference/__init__.py @@ -13,8 +13,5 @@ """Placeholder docstring""" from __future__ import absolute_import -from sagemaker.djl_inference.model import DJLPredictor # noqa: F401 +from sagemaker.djl_inference.djl_predictor import DJLPredictor # noqa: F401 from sagemaker.djl_inference.model import DJLModel # noqa: F401 -from sagemaker.djl_inference.model import DeepSpeedModel # noqa: F401 -from sagemaker.djl_inference.model import HuggingFaceAccelerateModel # noqa: F401 -from sagemaker.djl_inference.model import FasterTransformerModel # noqa: F401 diff --git a/src/sagemaker/djl_inference/defaults.py b/src/sagemaker/djl_inference/defaults.py deleted file mode 100644 index 64699de8f9..0000000000 --- a/src/sagemaker/djl_inference/defaults.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""Placeholder docstring""" -from __future__ import absolute_import - -STABLE_DIFFUSION_MODEL_TYPE = "stable-diffusion" - -VALID_MODEL_CONFIG_FILES = ["config.json", "model_index.json"] - -DEEPSPEED_RECOMMENDED_ARCHITECTURES = { - "bloom", - "opt", - "gpt_neox", - "gptj", - "gpt_neo", - "gpt2", - "xlm-roberta", - "roberta", - "bert", - STABLE_DIFFUSION_MODEL_TYPE, -} - -FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES = { - "t5", -} - -FASTER_TRANSFORMER_SUPPORTED_ARCHITECTURES = { - "bert", - "gpt2", - "bloom", - "opt", - "gptj", - "gpt_neox", - "gpt_neo", - "t5", -} - -ALLOWED_INSTANCE_FAMILIES = { - "ml.g4dn", - "ml.g5", - "ml.p3", - "ml.p3dn", - "ml.p4", - "ml.p4d", - "ml.p4de", - "local_gpu", -} - -REVISION_MAPPING = {"fp16": "float16", "fp32": "float32"} diff --git a/src/sagemaker/djl_inference/djl_predictor.py b/src/sagemaker/djl_inference/djl_predictor.py new file mode 100644 index 0000000000..e6ab10f676 --- /dev/null +++ b/src/sagemaker/djl_inference/djl_predictor.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Default Predictor for JSON inputs/outputs used with DJL LMI containers""" +from __future__ import absolute_import +from sagemaker.predictor import Predictor +from sagemaker import Session +from sagemaker.serializers import BaseSerializer, JSONSerializer +from sagemaker.deserializers import BaseDeserializer, JSONDeserializer + + +class DJLPredictor(Predictor): + """A Predictor for inference against DJL Model Endpoints. + + This is able to serialize Python lists, dictionaries, and numpy arrays to + multidimensional tensors for DJL inference. + """ + + def __init__( + self, + endpoint_name: str, + sagemaker_session: Session = None, + serializer: BaseSerializer = JSONSerializer(), + deserializer: BaseDeserializer = JSONDeserializer(), + component_name=None, + ): + """Initialize a ``DJLPredictor`` + + Args: + endpoint_name (str): The name of the endpoint to perform inference + on. + sagemaker_session (sagemaker.session.Session): Session object that + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, the estimator creates one + using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to json format. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses the response from json format to dictionary. + component_name (str): Optional. Name of the Amazon SageMaker inference + component corresponding the predictor. + """ + super(DJLPredictor, self).__init__( + endpoint_name, + sagemaker_session, + serializer=serializer, + deserializer=deserializer, + component_name=component_name, + ) diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index efbb44460c..94db4efe29 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -13,308 +13,52 @@ """Placeholder docstring""" from __future__ import absolute_import -import json import logging -import os.path -import urllib.request -from json import JSONDecodeError -from urllib.error import HTTPError, URLError -from enum import Enum -from typing import Optional, Union, Dict, Any, List +from typing import Callable, Optional, Dict, Any -import sagemaker -from sagemaker import s3, Predictor, image_uris, fw_utils -from sagemaker.deserializers import JSONDeserializer, BaseDeserializer -from sagemaker.djl_inference import defaults -from sagemaker.model import FrameworkModel -from sagemaker.s3_utils import s3_path_join -from sagemaker.serializers import JSONSerializer, BaseSerializer +from sagemaker import image_uris +from sagemaker.model import Model from sagemaker.session import Session -from sagemaker.utils import _tmpdir, _create_or_update_code_dir, format_tags -from sagemaker.workflow.entities import PipelineVariable -from sagemaker.estimator import Estimator -from sagemaker.s3 import S3Uploader -logger = logging.getLogger("sagemaker") +from sagemaker.djl_inference.djl_predictor import DJLPredictor -# DJL Serving uses log4j, so we convert python logging level to log4j equivalent -_LOG_LEVEL_MAP = { - logging.INFO: "info", - logging.DEBUG: "debug", - logging.WARNING: "warn", - logging.ERROR: "error", - logging.FATAL: "fatal", - logging.CRITICAL: "fatal", - logging.NOTSET: "off", -} +logger = 
logging.getLogger(__name__) -class DJLServingEngineEntryPointDefaults(Enum): - """Enum describing supported engines and corresponding default inference handler modules.""" +def _set_env_var_from_property( + property_value: Optional[Any], env_key: str, env: dict, override_env_var=False +) -> dict: + """Utility method to set an environment variable configuration""" + if not property_value: + return env + if override_env_var or env_key not in env: + env[env_key] = str(property_value) + return env - DEEPSPEED = ("DeepSpeed", "djl_python.deepspeed") - HUGGINGFACE_ACCELERATE = ("Python", "djl_python.huggingface") - STABLE_DIFFUSION = ("DeepSpeed", "djl_python.stable-diffusion") - FASTER_TRANSFORMER = ("FasterTransformer", "djl_python.fastertransformer") - -class DJLPredictor(Predictor): - """A Predictor for inference against DJL Model Endpoints. - - This is able to serialize Python lists, dictionaries, and numpy arrays to - multidimensional tensors for DJL inference. - """ - - def __init__( - self, - endpoint_name: str, - sagemaker_session: Session = None, - serializer: BaseSerializer = JSONSerializer(), - deserializer: BaseDeserializer = JSONDeserializer(), - component_name=None, - ): - """Initialize a ``DJLPredictor`` - - Args: - endpoint_name (str): The name of the endpoint to perform inference - on. - sagemaker_session (sagemaker.session.Session): Session object that - manages interactions with Amazon SageMaker APIs and any other - AWS services needed. If not specified, the estimator creates one - using the default AWS configuration chain. - serializer (sagemaker.serializers.BaseSerializer): Optional. Default - serializes input data to json format. - deserializer (sagemaker.deserializers.BaseDeserializer): Optional. - Default parses the response from json format to dictionary. - component_name (str): Optional. Name of the Amazon SageMaker inference - component corresponding the predictor. - """ - super(DJLPredictor, self).__init__( - endpoint_name, - sagemaker_session, - serializer=serializer, - deserializer=deserializer, - component_name=component_name, - ) - - -def _determine_engine_for_model(model_type: str, num_partitions: int, num_heads: int): - """Placeholder docstring""" - - # Tensor Parallelism is only possible if attention heads can be split evenly - # across devices - if num_heads is not None and num_partitions is not None and num_heads % num_partitions: - return HuggingFaceAccelerateModel - if model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES: - return DeepSpeedModel - if model_type in defaults.FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES: - return FasterTransformerModel - return HuggingFaceAccelerateModel - - -def _validate_engine_for_model_type(cls, model_type: str, num_partitions: int, num_heads: int): - """Placeholder docstring""" - - if cls == DeepSpeedModel: - if num_heads is not None and num_partitions is not None and num_heads % num_partitions: - raise ValueError( - "The number of attention heads is not evenly divisible by the number of partitions." - "Please set the number of partitions such that the number of attention heads can be" - "evenly split across the partitions." - ) - if cls == FasterTransformerModel: - if model_type not in defaults.FASTER_TRANSFORMER_SUPPORTED_ARCHITECTURES: - raise ValueError( - f"The model architecture {model_type} is currently not supported by " - f"FasterTransformer. Please use a different engine, or use the DJLModel" - f"to let SageMaker pick a recommended engine for this model." 
- ) - return cls - - -def _read_existing_serving_properties(directory: str): - """Placeholder docstring""" - - serving_properties_path = os.path.join(directory, "serving.properties") - properties = {} - if os.path.exists(serving_properties_path): - with open(serving_properties_path, "r") as f: - for line in f: - if line.startswith("#") or len(line.strip()) == 0: - continue - key, val = line.split("=", 1) - properties[key] = val - return properties - - -def _get_model_config_properties_from_s3(model_s3_uri: str, sagemaker_session: Session): - """Placeholder docstring""" - - s3_files = s3.S3Downloader.list(model_s3_uri, sagemaker_session=sagemaker_session) - model_config = None - for config in defaults.VALID_MODEL_CONFIG_FILES: - config_file = os.path.join(model_s3_uri, config) - if config_file in s3_files: - model_config = json.loads( - s3.S3Downloader.read_file(config_file, sagemaker_session=sagemaker_session) - ) - break - if not model_config: - raise ValueError( - f"Did not find a config.json or model_index.json file in {model_s3_uri}. Please make " - f"sure a config.json exists (or model_index.json for Stable Diffusion Models) in" - f"the provided s3 location" - ) - return model_config - - -def _get_model_config_properties_from_hf(model_id: str, hf_hub_token: str = None): - """Placeholder docstring""" - - config_url_prefix = f"https://huggingface.co/{model_id}/raw/main/" - model_config = None - for config in defaults.VALID_MODEL_CONFIG_FILES: - config_file_url = config_url_prefix + config - try: - if hf_hub_token: - config_file_url = urllib.request.Request( - config_file_url, None, {"Authorization": "Bearer " + hf_hub_token} - ) - with urllib.request.urlopen(config_file_url) as response: - model_config = json.load(response) - break - except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e: - if "HTTP Error 401: Unauthorized" in str(e): - raise ValueError( - "Trying to access a gated/private HuggingFace model without valid credentials. " - "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars" - ) - logger.warning( - "Exception encountered while trying to read config file %s. " "Details: %s", - config_file_url, - e, - ) - if not model_config: - raise ValueError( - f"Did not find a config.json or model_index.json file in huggingface hub for " - f"{model_id}. 
Please make sure a config.json exists (or model_index.json for Stable " - f"Diffusion Models) for this model in the huggingface hub" - ) - return model_config - - -def _create_estimator( - instance_type: str, - s3_output_uri: str, - image_uri: str, - role: str, - sagemaker_session: Optional[Session], - volume_size: int, - vpc_config: Optional[ - Dict[ - str, - List[str], - ] - ] = None, - volume_kms_key=None, - output_kms_key=None, - use_spot_instances: bool = False, - max_wait: int = None, - enable_network_isolation: bool = False, -): - """Placeholder docstring""" - - subnets = None - security_group_ids = None - if vpc_config: - subnets = vpc_config.get("Subnets") - security_group_ids = vpc_config.get("SecurityGroupIds") - - return Estimator( - image_uri=image_uri, - role=role, - instance_count=1, - instance_type=instance_type, - volume_size=volume_size, - volume_kms_key=volume_kms_key, - output_path=s3_output_uri, - output_kms_key=output_kms_key, - sagemaker_session=sagemaker_session, - subnets=subnets, - security_group_ids=security_group_ids, - use_spot_instances=use_spot_instances, - max_wait=max_wait, - enable_network_isolation=enable_network_isolation, - ) - - -class DJLModel(FrameworkModel): +class DJLModel(Model): """A DJL SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``.""" - def __new__( - cls, - model_id: str, - *args, - **kwargs, - ): # pylint: disable=W0613 - """Create a specific subclass of DJLModel for a given engine""" - if model_id.endswith("tar.gz"): - raise ValueError( - "DJLModel does not support model artifacts in tar.gz format." - "Please store the model in uncompressed format and provide the s3 uri of the " - "containing folder" - ) - if model_id.startswith("s3://"): - sagemaker_session = kwargs.get("sagemaker_session") - model_config = _get_model_config_properties_from_s3(model_id, sagemaker_session) - else: - hf_hub_token = kwargs.get("hf_hub_token") - model_config = _get_model_config_properties_from_hf(model_id, hf_hub_token) - if model_config.get("_class_name") == "StableDiffusionPipeline": - model_type = defaults.STABLE_DIFFUSION_MODEL_TYPE - num_heads = 0 - else: - model_type = model_config.get("model_type") - num_heads = model_config.get("n_head") or model_config.get("num_attention_heads") - number_of_partitions = kwargs.get("number_of_partitions") or kwargs.get( - "tensor_parallel_degree" - ) - cls_to_create = ( - _validate_engine_for_model_type(cls, model_type, number_of_partitions, num_heads) - if cls is not DJLModel - else _determine_engine_for_model(model_type, number_of_partitions, num_heads) - ) - instance = super().__new__(cls_to_create) - if model_type == defaults.STABLE_DIFFUSION_MODEL_TYPE: - instance.engine = DJLServingEngineEntryPointDefaults.STABLE_DIFFUSION - elif isinstance(instance, DeepSpeedModel): - instance.engine = DJLServingEngineEntryPointDefaults.DEEPSPEED - elif isinstance(instance, FasterTransformerModel): - instance.engine = DJLServingEngineEntryPointDefaults.FASTER_TRANSFORMER - else: - instance.engine = DJLServingEngineEntryPointDefaults.HUGGINGFACE_ACCELERATE - return instance - def __init__( self, - model_id: str, - role: str, - djl_version: Optional[str] = None, + model_id: Optional[str] = None, + engine: Optional[str] = None, + djl_version: str = "latest", + djl_framework: Optional[str] = None, task: Optional[str] = None, - dtype: str = "fp32", - number_of_partitions: Optional[int] = None, + dtype: Optional[str] = None, + tensor_parallel_degree: Optional[int] = None, min_workers: Optional[int] = None, 
max_workers: Optional[int] = None, job_queue_size: Optional[int] = None, parallel_loading: bool = False, model_loading_timeout: Optional[int] = None, prediction_timeout: Optional[int] = None, - entry_point: Optional[str] = None, - image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = DJLPredictor, + predictor_cls: Optional[Callable] = DJLPredictor, + huggingface_hub_token: Optional[str] = None, **kwargs, ): - """Initialize a DJLModel. + """Initialize a SageMaker model using one of the DJL Model Serving Containers. Args: model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location @@ -322,24 +66,23 @@ def __init__( The model artifacts are expected to be in HuggingFace pre-trained model format (i.e. model should be loadable from the huggingface transformers from_pretrained api, and should also include tokenizer configs if applicable). - role (str): An AWS IAM role specified with either the name or full ARN. The Amazon - SageMaker training jobs and APIs that create Amazon SageMaker - endpoints use this role to access model artifacts. After the endpoint is created, - the inference code might use the IAM role, if it needs to access an AWS resource. + model artifact location must be specified using either the model_id parameter, + model_data parameter, or HF_MODEL_ID environment variable in the env parameter + engine (str): The DJL inference engine to use for your model. Defaults to None. + If not provided, the engine is inferred based on the task. If no task is provided, + the Python engine is used. djl_version (str): DJL Serving version you want to use for serving your model for inference. Defaults to None. If not provided, the latest available version of DJL Serving is used. This is not used if ``image_uri`` is provided. + djl_framework (str): The DJL container to use. This is used along with djl_version + to fetch the image_uri of the djl inference container. This is not used if + ``image_uri`` is provided. task (str): The HuggingFace/NLP task you want to launch this model for. Defaults to None. If not provided, the task will be inferred from the model architecture by DJL. - dtype (str): The data type to use for loading your model. Accepted values are - "fp32", "fp16", "bf16", "int8". Defaults to "fp32". - number_of_partitions (int): The number of GPUs to partition the model across. The - partitioning strategy is determined by the selected backend. If DeepSpeed is - selected, this is tensor parallelism. - If HuggingFace Accelerate is selected, this is a naive sharding strategy - that splits the model layers across the available resources. Defaults to None. If - not provided, no model partitioning is done. + tensor_parallel_degree (int): The number of accelerators to partition the model across + using tensor parallelism. Defaults to None. If not provided, the maximum number + of available accelerators will be used. min_workers (int): The minimum number of worker processes. Defaults to None. If not provided, dJL Serving will automatically detect the minimum workers. max_workers (int): The maximum number of worker processes. Defaults to None. If not @@ -354,58 +97,26 @@ def __init__( None. If not provided, the default is 240 seconds. prediction_timeout (int): The worker predict call (handler) timeout in seconds. Defaults to None. If not provided, the default is 120 seconds. 
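Taken together with the ``DJLPredictor`` above, the reworked class is now a thin layer over ``Model`` that expresses its configuration as container environment variables. A minimal usage sketch follows; the model id, role ARN, and instance type are placeholders rather than values taken from this change:

```python
from sagemaker.djl_inference.model import DJLModel

model = DJLModel(
    model_id="mistralai/Mistral-7B-v0.1",  # placeholder HF Hub id, or an s3:// prefix of uncompressed artifacts
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder, forwarded to Model via **kwargs
    task="text-generation",      # exported as HF_TASK
    dtype="fp16",                # exported as OPTION_DTYPE
    tensor_parallel_degree=2,    # exported as TENSOR_PARALLEL_DEGREE
)

# deploy() returns the default predictor_cls, a DJLPredictor that sends JSON
# and parses the JSON response.
predictor = model.deploy(initial_instance_count=1, instance_type="ml.g5.12xlarge")
print(predictor.predict({"inputs": "What is Amazon SageMaker?"}))
```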
- entry_point (str): This can either be the absolute or relative path to the Python source - file that should be executed as the entry point to model - hosting, or a python module that is installed in the container. If ``source_dir`` - is specified, then ``entry_point`` - must point to a file located at the root of ``source_dir``. Defaults to None. - image_uri (str): A docker image URI. Defaults to None. If not specified, a default - image for DJL Serving will be used based on ``djl_version``. If ``djl_version`` - is not specified, the latest available container version will be used. - predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a - predictor with an endpoint name and SageMaker ``Session``. If specified, - ``deploy()`` returns - the result of invoking this function on the created endpoint name. + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call + to create a predictor with an endpoint name and SageMaker ``Session``. If + specified, ``deploy()`` returns the result of invoking this function on the created + endpoint name. + huggingface_hub_token (str): The HuggingFace Hub token to use for downloading the model + artifacts for a model stored on the huggingface hub. + Defaults to None. If not provided, the token must be specified in the + HF_TOKEN environment variable in the env parameter. **kwargs: Keyword arguments passed to the superclass :class:`~sagemaker.model.FrameworkModel` and, subsequently, its superclass :class:`~sagemaker.model.Model`. - - .. tip:: - - Instantiating a DJLModel will return an instance of either - :class:`~sagemaker.djl_inference.DeepSpeedModel` or - :class:`~sagemaker.djl_inference.HuggingFaceAccelerateModel` based on our framework - recommendation for the model type. - - If you want to use a specific framework to deploy your model with, we recommend - instantiating that specific - model class directly. The available framework specific classes are - :class:`~sagemaker.djl_inference.DeepSpeedModel` or - :class:`~sagemaker.djl_inference.HuggingFaceAccelerateModel` """ - if "hf_hub_token" in kwargs: - kwargs.pop("hf_hub_token") - if kwargs.get("model_data"): - logger.warning( - "DJLModels do not use model_data parameter. model_data parameter will be ignored." - "You only need to set model_id and ensure it points to uncompressed model " - "artifacts in s3, or a valid HuggingFace Hub model_id." - ) - data_type = kwargs.pop("data_type", None) - if data_type: - logger.warning( - "data_type is being deprecated in favor of dtype. Please migrate use of data_type" - " to dtype. 
Support for data_type will be removed in a future release" - ) - dtype = dtype or data_type - super(DJLModel, self).__init__( - None, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs - ) + super(DJLModel, self).__init__(predictor_cls=predictor_cls, **kwargs) self.model_id = model_id self.djl_version = djl_version + self.djl_framework = djl_framework + self.engine = engine self.task = task self.dtype = dtype - self.number_of_partitions = number_of_partitions + self.tensor_parallel_degree = tensor_parallel_degree self.min_workers = min_workers self.max_workers = max_workers self.job_queue_size = job_queue_size @@ -413,7 +124,85 @@ def __init__( self.model_loading_timeout = model_loading_timeout self.prediction_timeout = prediction_timeout self.sagemaker_session = self.sagemaker_session or Session() - self.save_mp_checkpoint_path = None + self.hub_token = huggingface_hub_token + self._initialize_model() + + def _initialize_model(self): + """Placeholder docstring""" + self._validate_model_artifacts() + self.engine = self._infer_engine() + self.env = self._configure_environment_variables() + self.image_uri = self._infer_image_uri() + + def _validate_model_artifacts(self): + """Placeholder docstring""" + if self.model_id is not None and self.model_data is not None: + raise ValueError( + "both model_id and model_data are provided. Please only provide one of them" + ) + + def _infer_engine(self) -> Optional[str]: + """Placeholder docstring""" + if self.engine is not None: + logger.info("Using provided engine %s", self.engine) + return self.engine + + if self.task == "text-embedding": + return "OnnxRuntime" + return "Python" + + def _infer_image_uri(self): + """Placeholder docstring""" + if self.image_uri is not None: + return self.image_uri + if self.djl_framework is None: + self.djl_framework = "djl-lmi" + return image_uris.retrieve( + framework=self.djl_framework, + region=self.sagemaker_session.boto_region_name, + version=self.djl_version, + ) + + def _configure_environment_variables(self) -> Dict[str, str]: + """Placeholder docstring""" + env = self.env.copy() if self.env else {} + env = _set_env_var_from_property(self.model_id, "HF_MODEL_ID", env) + env = _set_env_var_from_property(self.task, "HF_TASK", env) + env = _set_env_var_from_property(self.dtype, "OPTION_DTYPE", env) + env = _set_env_var_from_property(self.min_workers, "SERVING_MIN_WORKERS", env) + env = _set_env_var_from_property(self.max_workers, "SERVING_MAX_WORKERS", env) + env = _set_env_var_from_property(self.job_queue_size, "SERVING_JOB_QUEUE_SIZE", env) + env = _set_env_var_from_property(self.parallel_loading, "OPTION_PARALLEL_LOADING", env) + env = _set_env_var_from_property( + self.model_loading_timeout, "OPTION_MODEL_LOADING_TIMEOUT", env + ) + env = _set_env_var_from_property(self.prediction_timeout, "OPTION_PREDICT_TIMEOUT", env) + env = _set_env_var_from_property(self.hub_token, "HF_TOKEN", env) + env = _set_env_var_from_property(self.engine, "OPTION_ENGINE", env) + if "TENSOR_PARALLEL_DEGREE" not in env or "OPTION_TENSOR_PARALLEL_DEGREE" not in env: + if self.tensor_parallel_degree is not None: + env["TENSOR_PARALLEL_DEGREE"] = str(self.tensor_parallel_degree) + return env + + def serving_image_uri( + self, + region_name, + instance_type=None, + accelerator_type=None, + serverless_inference_config=None, + ): + """Placeholder docstring""" + if self.image_uri: + return self.image_uri + return image_uris.retrieve( + framework=self.djl_framework, + region=region_name, + version=self.djl_version, + 
instance_type=instance_type, + accelerator_type=accelerator_type, + image_scope="inference", + serverless_inference_config=serverless_inference_config, + ) def package_for_edge(self, **_): """Not implemented. @@ -460,791 +249,3 @@ def right_size(self, **_): raise NotImplementedError( "DJLModels do not currently support Inference Recommendation Jobs" ) - - def partition( - self, - instance_type: str, - s3_output_uri: str = None, - s3_output_prefix: str = "aot-partitioned-checkpoints", - job_name: Optional[str] = None, - volume_size: int = 30, - volume_kms_key: Optional[str] = None, - output_kms_key: Optional[str] = None, - use_spot_instances: bool = False, - max_wait: int = None, - enable_network_isolation: bool = False, - ): - """Partitions the model using SageMaker Training Job. This is a synchronous API call. - - Args: - instance_type (str): The EC2 instance type to partition this Model. - For example, 'ml.p4d.24xlarge'. - s3_output_uri (str): S3 location for saving the training result (model - artifacts and output files). If not specified, results are - stored to a default bucket. If the bucket with the specific name - does not exist, it will be created. - s3_output_prefix (str): Name of the prefix where all the partitioned - checkpoints to be uploaded. If not provided, the default value is - aot-partitioned-checkpoints. - job_name (str): Training job name. If not specified, a unique training job - name will be created. - volume_size (int): Size in GB of the storage volume to use for - storing input and output data during training (default: 30). - volume_kms_key (str): Optional. KMS key ID for encrypting EBS - volume attached to the training instance (default: None). - output_kms_key (str): Optional. KMS key ID for encrypting the - training output (default: None). - use_spot_instances (bool): Specifies whether to use SageMaker - Managed Spot instances for training. If enabled then the - ``max_wait`` arg should also be set. - - More information: - https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html - (default: ``False``). - max_wait (int): Timeout in seconds waiting for spot training - job (default: None). After this amount of time Amazon - SageMaker will stop waiting for managed spot training job to - complete (default: None). - enable_network_isolation (bool): Specifies whether container will - run in network isolation mode (default: ``False``). Network - isolation mode restricts the container access to outside networks - (such as the Internet). The container does not make any inbound or - outbound network calls. Also known as Internet-free mode. 
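The environment-variable folding introduced above is the core of the rewrite, so a standalone mirror of its behavior may help. This reimplements ``_set_env_var_from_property`` rather than importing it, and the keys and values are illustrative only:

```python
def set_env_var(value, key, env, override=False):
    # Mirrors _set_env_var_from_property: falsy values are skipped, and an
    # existing (user-supplied) entry wins unless override is requested.
    if not value:
        return env
    if override or key not in env:
        env[key] = str(value)
    return env

env = {"OPTION_DTYPE": "bf16"}                       # pretend the caller passed env={...}
env = set_env_var("fp16", "OPTION_DTYPE", env)       # kept as bf16: the user value wins
env = set_env_var("text-embedding", "HF_TASK", env)  # added; this task also selects OnnxRuntime
env = set_env_var(None, "SERVING_MIN_WORKERS", env)  # skipped: falsy
print(env)  # {'OPTION_DTYPE': 'bf16', 'HF_TASK': 'text-embedding'}
```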
- Returns: - None - """ - - if not self.image_uri: - region_name = self.sagemaker_session.boto_session.region_name - self.image_uri = self.serving_image_uri(region_name) - - if s3_output_uri is None: - deploy_key_prefix = fw_utils.model_code_key_prefix( - self.key_prefix, self.name, self.image_uri - ) - - bucket, deploy_key_prefix = s3.determine_bucket_and_prefix( - bucket=self.bucket, - key_prefix=deploy_key_prefix, - sagemaker_session=self.sagemaker_session, - ) - s3_output_uri = s3_path_join("s3://", bucket, deploy_key_prefix) - - self.save_mp_checkpoint_path = s3_path_join(s3_output_uri, s3_output_prefix) - - container_def = self._upload_model_to_s3(upload_as_tar=False) - estimator = _create_estimator( - instance_type=instance_type, - s3_output_uri=s3_output_uri, - image_uri=self.image_uri, - role=self.role, - sagemaker_session=self.sagemaker_session, - volume_size=volume_size, - vpc_config=self.vpc_config, - volume_kms_key=volume_kms_key, - output_kms_key=output_kms_key, - use_spot_instances=use_spot_instances, - max_wait=max_wait, - enable_network_isolation=enable_network_isolation, - ) - - # creates a training job to do partitions - estimator.fit( - inputs=container_def["ModelDataUrl"], - wait=True, - logs="All", - job_name=job_name, - experiment_config=None, - ) - - self.model_id = self.save_mp_checkpoint_path - # reset save_mp_checkpoint_path since partition is completed. - self.save_mp_checkpoint_path = None - - def deploy( - self, - instance_type, - initial_instance_count=1, - serializer=None, - deserializer=None, - endpoint_name=None, - tags=None, - kms_key=None, - wait=True, - data_capture_config=None, - volume_size=None, - model_data_download_timeout=None, - container_startup_health_check_timeout=None, - **kwargs, - ): - """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. - - Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an - ``Endpoint`` from this ``Model``. If ``self.predictor_cls`` is not None, - this method returns the result of invoking ``self.predictor_cls`` on - the created endpoint name. - - The name of the created model is accessible in the ``name`` field of - this ``Model`` after deploy returns - - The name of the created endpoint is accessible in the - ``endpoint_name`` field of this ``Model`` after deploy returns. - - Args: - instance_type (str): The EC2 instance type to deploy this Model to. - For example, 'ml.p4d.24xlarge'. - initial_instance_count (int): The initial number of instances to run - in the ``Endpoint`` created from this ``Model``. It needs to be at least 1 ( - default: 1) - serializer (:class:`~sagemaker.serializers.BaseSerializer`): A - serializer object, used to encode data for an inference endpoint - (default: None). If ``serializer`` is not None, then - ``serializer`` will override the default serializer. The - default serializer is set by the ``predictor_cls``. - deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A - deserializer object, used to decode data from an inference - endpoint (default: None). If ``deserializer`` is not None, then - ``deserializer`` will override the default deserializer. The - default deserializer is set by the ``predictor_cls``. - endpoint_name (str): The name of the endpoint to create (default: - None). If not specified, a unique endpoint name will be created. - tags (Optional[Tags]): The list of tags to attach to this - specific endpoint. 
- kms_key (str): The ARN of the KMS key that is used to encrypt the - data on the storage volume attached to the instance hosting the - endpoint. - wait (bool): Whether the call should wait until the deployment of - this model completes (default: True). - data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies - configuration related to Endpoint data capture for use with - Amazon SageMaker Model Monitoring. Default: None. - volume_size (int): The size, in GB, of the ML storage volume attached to individual - inference instance associated with the production variant. Currenly only Amazon EBS - gp2 storage volumes are supported. - model_data_download_timeout (int): The timeout value, in seconds, to download and - extract model data from Amazon S3 to the individual inference instance associated - with this production variant. - container_startup_health_check_timeout (int): The timeout value, in seconds, for your - inference container to pass health check by SageMaker Hosting. For more information - about health check see: - https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code - .html#your-algorithms-inference-algo-ping-requests - - Returns: - callable[string, sagemaker.session.Session] or None: Invocation of - ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` - is not None. Otherwise, return None. - """ - - instance_family = instance_type.rsplit(".", 1)[0] - if instance_family not in defaults.ALLOWED_INSTANCE_FAMILIES: - raise ValueError( - f"Invalid instance type. DJLModels only support deployment to instances" - f"with GPUs. Supported instance families are {defaults.ALLOWED_INSTANCE_FAMILIES}" - ) - - return super(DJLModel, self).deploy( - initial_instance_count=initial_instance_count, - instance_type=instance_type, - serializer=serializer, - deserializer=deserializer, - endpoint_name=endpoint_name, - tags=format_tags(tags), - kms_key=kms_key, - wait=wait, - data_capture_config=data_capture_config, - volume_size=volume_size, - model_data_download_timeout=model_data_download_timeout, - container_startup_health_check_timeout=container_startup_health_check_timeout, - **kwargs, - ) - - def _upload_model_to_s3(self, upload_as_tar: bool = True): - """Placeholder docstring""" - - if not self.image_uri: - region_name = self.sagemaker_session.boto_session.region_name - self.image_uri = self.serving_image_uri(region_name) - - environment = self._get_container_env() - - local_download_dir = ( - None - if self.sagemaker_session.settings is None - or self.sagemaker_session.settings.local_download_dir is None - else self.sagemaker_session.settings.local_download_dir - ) - with _tmpdir(directory=local_download_dir) as tmp: - if self.source_dir or self.entry_point: - # Below method downloads from s3, or moves local files to tmp/code - _create_or_update_code_dir( - tmp, - self.entry_point, - self.source_dir, - self.dependencies, - self.sagemaker_session, - tmp, - ) - tmp_code_dir = os.path.join(tmp, "code") - existing_serving_properties = _read_existing_serving_properties(tmp_code_dir) - kwargs_serving_properties = self.generate_serving_properties() - existing_serving_properties.update(kwargs_serving_properties) - - if not os.path.exists(tmp_code_dir): - os.mkdir(tmp_code_dir) - with open(os.path.join(tmp_code_dir, "serving.properties"), "w+") as f: - for key, val in existing_serving_properties.items(): - f.write(f"{key}={val}\n") - - deploy_key_prefix = fw_utils.model_code_key_prefix( - self.key_prefix, self.name, self.image_uri - ) 
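For readers tracking what disappears in these hunks: the deleted upload path materialized the configuration as a ``serving.properties`` file in a staged code directory. A compressed sketch of that legacy output, with keys taken from the removed ``generate_serving_properties`` and values that are only illustrative:

```python
# Legacy flow being removed: settings were written to serving.properties and
# uploaded to S3 next to any entry point code.
legacy_properties = {
    "engine": "Python",
    "option.entryPoint": "djl_python.huggingface",
    "option.model_id": "s3://my-bucket/uncompressed-model/",  # placeholder
    "option.tensor_parallel_degree": 2,
}

with open("serving.properties", "w") as f:
    for key, val in legacy_properties.items():
        f.write(f"{key}={val}\n")
# The rewritten DJLModel expresses the same settings purely as environment
# variables, so no code directory or S3 staging step is needed.
```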
- bucket, deploy_key_prefix = s3.determine_bucket_and_prefix( - bucket=self.bucket, - key_prefix=deploy_key_prefix, - sagemaker_session=self.sagemaker_session, - ) - if upload_as_tar: - uploaded_code = fw_utils.tar_and_upload_dir( - self.sagemaker_session.boto_session, - bucket, - deploy_key_prefix, - self.entry_point, - directory=tmp_code_dir, - dependencies=self.dependencies, - kms_key=self.model_kms_key, - ) - model_data_url = uploaded_code.s3_prefix - else: - model_data_url = S3Uploader.upload( - tmp_code_dir, - s3_path_join("s3://", bucket, deploy_key_prefix, "aot-model"), - self.model_kms_key, - self.sagemaker_session, - ) - return sagemaker.container_def( - self.image_uri, model_data_url=model_data_url, env=environment - ) - - def prepare_container_def( - self, - instance_type=None, - accelerator_type=None, - serverless_inference_config=None, - accept_eula=None, - ): # pylint: disable=unused-argument - """A container definition with framework configuration set in model environment variables. - - Returns: - dict[str, str]: A container definition object usable with the - CreateModel API. - """ - - if not self.model_data and not isinstance(self.model_data, dict): - return self._upload_model_to_s3(upload_as_tar=True) - return super().prepare_container_def( - instance_type, accelerator_type, serverless_inference_config - ) - - def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]: - """Generates the DJL Serving configuration to use for the model. - - The configuration is generated using the arguments passed to the Model during - initialization. If a serving.properties file is found in ``self.source_dir``, - those configuration as merged with the Model parameters, with Model parameters taking - priority. - - Args: - serving_properties: Dictionary containing existing model server configuration - obtained from ``self.source_dir``. Defaults to None. - - Returns: - dict: The model server configuration to use when deploying this model to SageMaker. - """ - if not serving_properties: - serving_properties = {} - serving_properties["engine"] = self.engine.value[0] # pylint: disable=E1101 - serving_properties["option.entryPoint"] = self.engine.value[1] # pylint: disable=E1101 - serving_properties["option.model_id"] = self.model_id - if self.number_of_partitions: - serving_properties["option.tensor_parallel_degree"] = self.number_of_partitions - if self.entry_point: - serving_properties["option.entryPoint"] = self.entry_point - if self.task: - serving_properties["option.task"] = self.task - if self.dtype: - serving_properties["option.dtype"] = self.dtype - if self.min_workers: - serving_properties["minWorkers"] = self.min_workers - if self.max_workers: - serving_properties["maxWorkers"] = self.max_workers - if self.job_queue_size: - serving_properties["job_queue_size"] = self.job_queue_size - if self.parallel_loading: - serving_properties["option.parallel_loading"] = self.parallel_loading - if self.model_loading_timeout: - serving_properties["option.model_loading_timeout"] = self.model_loading_timeout - if self.prediction_timeout: - serving_properties["option.prediction_timeout"] = self.prediction_timeout - if self.save_mp_checkpoint_path: - serving_properties["option.save_mp_checkpoint_path"] = self.save_mp_checkpoint_path - return serving_properties - - def serving_image_uri(self, region_name): - """Create a URI for the serving image. - - Args: - region_name (str): AWS region where the image is uploaded. 
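The removed ``serving_image_uri`` below pinned a fallback ``djl_version`` of 0.24.0; its replacement always resolves through ``image_uris.retrieve`` with ``djl_framework`` defaulting to ``djl-lmi``. A sketch of the equivalent direct call; the region is a placeholder, and whether ``version="latest"`` resolves depends on the SDK's image configuration:

```python
from sagemaker import image_uris

uri = image_uris.retrieve(
    framework="djl-lmi",   # the new default djl_framework
    region="us-west-2",    # placeholder region
    version="latest",      # the new constructor's default djl_version
)
print(uri)
```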
- - Returns: - str: The appropriate image URI based on the given parameters. - """ - if not self.djl_version: - self.djl_version = "0.24.0" - - return image_uris.retrieve( - self._framework(), - region_name, - version=self.djl_version, - ) - - def _get_container_env(self): - """Placeholder docstring""" - - if not self.container_log_level: - return self.env - - if self.container_log_level not in _LOG_LEVEL_MAP: - logger.warning("Ignoring invalid container log level: %s", self.container_log_level) - return self.env - - self.env["SERVING_OPTS"] = ( - f'"-Dai.djl.logging.level={_LOG_LEVEL_MAP[self.container_log_level]}"' - ) - return self.env - - -class DeepSpeedModel(DJLModel): - """A DJL DeepSpeed SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``""" - - _framework_name = "djl-deepspeed" - - def __init__( - self, - model_id: str, - role: str, - tensor_parallel_degree: Optional[int] = None, - max_tokens: Optional[int] = None, - low_cpu_mem_usage: bool = False, - enable_cuda_graph: bool = False, - triangular_masking: bool = True, - return_tuple: bool = True, - **kwargs, - ): - """Initialize a DeepSpeedModel - - Args: - model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location - containing the uncompressed model artifacts (i.e. not a tar.gz file). - The model artifacts are expected to be in HuggingFace pre-trained model - format (i.e. model should be loadable from the huggingface transformers - from_pretrained api, and should also include tokenizer configs if applicable). - role (str): An AWS IAM role specified with either the name or full ARN. The Amazon - SageMaker training jobs and APIs that create Amazon SageMaker - endpoints use this role to access model artifacts. After the endpoint is created, - the inference code - might use the IAM role, if it needs to access an AWS resource. - tensor_parallel_degree (int): The number of gpus to shard a single instance of the - model across via tensor_parallelism. This should be set to greater than 1 if the - size of the model is larger than the memory available on a single GPU on the - instance. Defaults to None. If not set, no tensor parallel sharding is done. - max_tokens (int): The maximum number of tokens (input + output tokens) the DeepSpeed - engine is configured for. Defaults to None. If not set, the DeepSpeed default of - 1024 is used. - low_cpu_mem_usage (bool): Whether to limit CPU memory usage to 1x model size during - model loading. This is an experimental feature in HuggingFace. This is useful when - loading multiple instances of your model in parallel. Defaults to False. - enable_cuda_graph (bool): Whether to enable CUDA graph replay to accelerate inference - passes. This cannot be used with tensor parallelism greater than 1. - Defaults to False. - triangular_masking (bool): Whether to use triangular attention mask. This is - application specific. Defaults to True. - return_tuple (bool): Whether the transformer layers need to return a tuple or a - Tensor. Defaults to True. - **kwargs: Keyword arguments passed to the superclasses - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model` - - .. tip:: - - You can find additional parameters for initializing this class at - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model`. 
- """ - if "hf_hub_token" in kwargs: - kwargs.pop("hf_hub_token") - super(DeepSpeedModel, self).__init__( - model_id, - role, - **kwargs, - ) - if self.number_of_partitions and tensor_parallel_degree: - logger.warning( - "Both number_of_partitions and tensor_parallel_degree have been set for " - "DeepSpeedModel." - "These mean the same thing for DeepSpeedModel. Please only set " - "tensor_parallel_degree." - "number_of_partitions will be ignored" - ) - self.number_of_partitions = tensor_parallel_degree or self.number_of_partitions - self.max_tokens = max_tokens - self.low_cpu_mem_usage = low_cpu_mem_usage - self.enable_cuda_graph = enable_cuda_graph - self.triangular_masking = triangular_masking - self.return_tuple = return_tuple - self.save_mp_checkpoint_path = None - self.checkpoint = None - - def generate_serving_properties(self, serving_properties=None) -> Dict[str, Any]: - """Generates the DJL Serving configuration to use for the model. - - The configuration is generated using the arguments passed to the Model during - initialization. If a serving.properties file is found in ``self.source_dir``, - those configuration as merged with the Model parameters, with Model parameters taking - priority. - - Args: - serving_properties: Dictionary containing existing model server configuration - obtained from ``self.source_dir``. Defaults to None. - - Returns: - dict: The model server configuration to use when deploying this model to SageMaker. - """ - - serving_properties = super(DeepSpeedModel, self).generate_serving_properties( - serving_properties=serving_properties - ) - if self.max_tokens: - serving_properties["option.max_tokens"] = self.max_tokens - if self.low_cpu_mem_usage: - serving_properties["option.low_cpu_mem_usage"] = self.low_cpu_mem_usage - if self.enable_cuda_graph: - if self.number_of_partitions > 1: - raise ValueError( - "enable_cuda_graph is not supported when tensor_parallel_degree > 1" - ) - serving_properties["option.enable_cuda_graph"] = self.enable_cuda_graph - if self.triangular_masking: - serving_properties["option.triangular_masking"] = self.triangular_masking - if self.return_tuple: - serving_properties["option.return_tuple"] = self.return_tuple - if self.save_mp_checkpoint_path: - serving_properties["option.save_mp_checkpoint_path"] = self.save_mp_checkpoint_path - if self.checkpoint: - serving_properties["option.checkpoint"] = self.checkpoint - - return serving_properties - - def partition( - self, - instance_type: str, - s3_output_uri: str = None, - s3_output_prefix: str = "aot-partitioned-checkpoints", - job_name: Optional[str] = None, - volume_size: int = 30, - volume_kms_key: Optional[str] = None, - output_kms_key: Optional[str] = None, - use_spot_instances: bool = False, - max_wait: int = None, - enable_network_isolation: bool = False, - ): - """Partitions the model using SageMaker Training Job. This is a synchronous API call. - - Args: - instance_type (str): The EC2 instance type to partition this Model. - For example, 'ml.p4d.24xlarge'. - s3_output_uri (str): S3 location for saving the training result (model - artifacts and output files). If not specified, results are - stored to a default bucket. If the bucket with the specific name - does not exist, it will be created. - s3_output_prefix (str): Name of the prefix where all the partitioned - checkpoints to be uploaded. If not provided, the default value is - aot-partitioned-checkpoints. - job_name (str): Training job name. If not specified, a unique training job - name will be created. 
- volume_size (int): Size in GB of the storage volume to use for - storing input and output data during training (default: 30). - volume_kms_key (str): Optional. KMS key ID for encrypting EBS - volume attached to the training instance (default: None). - output_kms_key (str): Optional. KMS key ID for encrypting the - training output (default: None). - use_spot_instances (bool): Specifies whether to use SageMaker - Managed Spot instances for training. If enabled then the - ``max_wait`` arg should also be set. - - More information: - https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html - (default: ``False``). - max_wait (int): Timeout in seconds waiting for spot training - job (default: None). After this amount of time Amazon - SageMaker will stop waiting for managed spot training job to - complete (default: None). - enable_network_isolation (bool): Specifies whether container will - run in network isolation mode (default: ``False``). Network - isolation mode restricts the container access to outside networks - (such as the Internet). The container does not make any inbound or - outbound network calls. Also known as Internet-free mode. - Returns: - None - """ - - super(DeepSpeedModel, self).partition( - instance_type, - s3_output_uri, - s3_output_prefix=s3_output_prefix, - job_name=job_name, - volume_size=volume_size, - volume_kms_key=volume_kms_key, - output_kms_key=output_kms_key, - use_spot_instances=use_spot_instances, - max_wait=max_wait, - enable_network_isolation=enable_network_isolation, - ) - - self.checkpoint = "ds_inference_config.json" - - -class HuggingFaceAccelerateModel(DJLModel): - """A DJL Hugging Face SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``.""" - - _framework_name = "djl-deepspeed" - - def __init__( - self, - model_id: str, - role: str, - number_of_partitions: Optional[int] = None, - device_id: Optional[int] = None, - device_map: Optional[Union[str, Dict[str, str]]] = None, - load_in_8bit: bool = False, - low_cpu_mem_usage: bool = False, - **kwargs, - ): - """Initialize a HuggingFaceAccelerateModel. - - Args: - model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location - containing the uncompressed model artifacts (i.e. not a tar.gz file). - The model artifacts are expected to be in HuggingFace pre-trained model - format (i.e. model should be loadable from the huggingface transformers - from_pretrained api, and should also include tokenizer configs if applicable). - role (str): An AWS IAM role specified with either the name or full ARN. The Amazon - SageMaker training jobs and APIs that create Amazon SageMaker - endpoints use this role to access model artifacts. After the endpoint is created, - the inference code - might use the IAM role, if it needs to access an AWS resource. - number_of_partitions (int): The number of GPUs to partition the model across. The - partitioning strategy is determined by the device_map setting. If device_map is - not specified, the default HuggingFace strategy will be used. - device_id (int): The device_id to use for instantiating the model. If provided, - the model will only be instantiated once on the indicated device. Do not set this - if you have also specified data_parallel_degree. Defaults to None. - device_map (str or dict): The HuggingFace accelerate device_map to use. Defaults to - None. - load_in_8bit (bool): Whether to load the model in int8 precision using bits and bytes - quantization. This is only supported for select model architectures. 
- Defaults to False. If ``dtype`` is int8, then this is set to True. - low_cpu_mem_usage (bool): Whether to limit CPU memory usage to 1x model size during - model loading. This is an experimental feature in HuggingFace. This is useful when - loading multiple instances of your model in parallel. Defaults to False. - **kwargs: Keyword arguments passed to the superclasses - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model` - - .. tip:: - - You can find additional parameters for initializing this class at - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model`. - """ - if "hf_hub_token" in kwargs: - kwargs.pop("hf_hub_token") - super(HuggingFaceAccelerateModel, self).__init__( - model_id, - role, - number_of_partitions=number_of_partitions, - **kwargs, - ) - self.device_id = device_id - self.device_map = device_map - self.load_in_8bit = load_in_8bit - self.low_cpu_mem_usage = low_cpu_mem_usage - - def generate_serving_properties(self, serving_properties=None) -> Dict[str, str]: - """Generates the DJL Serving configuration to use for the model. - - The configuration is generated using the arguments passed to the Model during - initialization. If a serving.properties file is found in ``self.source_dir``, - those configuration as merged with the Model parameters, with Model parameters taking - priority. - - Args: - serving_properties: Dictionary containing existing model server configuration - obtained from ``self.source_dir``. Defaults to None. - - Returns: - dict: The model server configuration to use when deploying this model to SageMaker. - """ - serving_properties = super(HuggingFaceAccelerateModel, self).generate_serving_properties( - serving_properties=serving_properties - ) - if self.device_id: - if self.number_of_partitions > 1: - raise ValueError("device_id cannot be set when number_of_partitions is > 1") - serving_properties["option.device_id"] = self.device_id - if self.device_map: - serving_properties["option.device_map"] = self.device_map - if self.load_in_8bit: - if self.dtype != "int8": - raise ValueError("Set dtype='int8' to use load_in_8bit") - serving_properties["option.load_in_8bit"] = self.load_in_8bit - if self.dtype == "int8": - serving_properties["option.load_in_8bit"] = True - if self.low_cpu_mem_usage: - serving_properties["option.low_cpu_mem_usage"] = self.low_cpu_mem_usage - # This is a workaround due to a bug in our built in handler for huggingface - # TODO: Remove this logic whenever 0.20.0 image is out of service - if ( - serving_properties["option.entryPoint"] == "djl_python.huggingface" - and self.dtype - and self.dtype != "auto" - and self.djl_version - and int(self.djl_version.split(".")[1]) < 21 - ): - serving_properties["option.dtype"] = "auto" - serving_properties.pop("option.load_in_8bit", None) - return serving_properties - - def partition( - self, - instance_type: str, - s3_output_uri: str = None, - s3_output_prefix: str = "aot-partitioned-checkpoints", - job_name: Optional[str] = None, - volume_size: int = 30, - volume_kms_key: Optional[str] = None, - output_kms_key: Optional[str] = None, - use_spot_instances: bool = False, - max_wait: int = None, - enable_network_isolation: bool = False, - ): - """Partitions the model using SageMaker Training Job. This is a synchronous API call. - - Args: - instance_type (str): The EC2 instance type to partition this Model. - For example, 'ml.p4d.24xlarge'. 
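The same migration pattern applies to ``HuggingFaceAccelerateModel``; the ``OPTION_*`` names below mirror the removed ``option.device_map`` and ``option.low_cpu_mem_usage`` serving properties and are assumptions, not part of this diff:

```python
from sagemaker.djl_inference.model import DJLModel

model = DJLModel(
    model_id="bigscience/bloom-7b1",         # placeholder
    dtype="fp16",                             # exported as OPTION_DTYPE
    env={
        "OPTION_DEVICE_MAP": "auto",          # was device_map (assumed env name)
        "OPTION_LOW_CPU_MEM_USAGE": "true",   # was low_cpu_mem_usage (assumed env name)
    },
)
```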
- s3_output_uri (str): S3 location for saving the training result (model - artifacts and output files). If not specified, results are - stored to a default bucket. If the bucket with the specific name - does not exist, it will be created. - s3_output_prefix (str): Name of the prefix where all the partitioned - checkpoints to be uploaded. If not provided, the default value is - aot-partitioned-checkpoints. - job_name (str): Training job name. If not specified, a unique training job - name will be created. - volume_size (int): Size in GB of the storage volume to use for - storing input and output data during training (default: 30). - volume_kms_key (str): Optional. KMS key ID for encrypting EBS - volume attached to the training instance (default: None). - output_kms_key (str): Optional. KMS key ID for encrypting the - training output (default: None). - use_spot_instances (bool): Specifies whether to use SageMaker - Managed Spot instances for training. If enabled then the - ``max_wait`` arg should also be set. - - More information: - https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html - (default: ``False``). - max_wait (int): Timeout in seconds waiting for spot training - job (default: None). After this amount of time Amazon - SageMaker will stop waiting for managed spot training job to - complete (default: None). - enable_network_isolation (bool): Specifies whether container will - run in network isolation mode (default: ``False``). Network - isolation mode restricts the container access to outside networks - (such as the Internet). The container does not make any inbound or - outbound network calls. Also known as Internet-free mode. - Returns: - None - """ - - logger.warning( - "HuggingFace engine does not currently support tensor parallelism. " - "Hence ahead of time partitioning is skipped" - ) - - -class FasterTransformerModel(DJLModel): - """A DJL FasterTransformer SageMaker ``Model`` - - This can be deployed to a SageMaker ``Endpoint``. - """ - - _framework_name = "djl-fastertransformer" - - def __init__( - self, - model_id: str, - role: str, - tensor_parallel_degree: Optional[int] = None, - **kwargs, - ): - """Initialize a FasterTransformerModel. - - Args: - model_id (str): This is either the HuggingFace Hub model_id, or the Amazon S3 location - containing the uncompressed model artifacts (i.e. not a tar.gz file). - The model artifacts are expected to be in HuggingFace pre-trained model - format (i.e. model should be loadable from the huggingface transformers - from_pretrained api, and should also include tokenizer configs if applicable). - role (str): An AWS IAM role specified with either the name or full ARN. The Amazon - SageMaker training jobs and APIs that create Amazon SageMaker - endpoints use this role to access model artifacts. After the endpoint is created, - the inference code - might use the IAM role, if it needs to access an AWS resource. - tensor_parllel_degree (int): The number of gpus to shard a single instance of the - model across via tensor_parallelism. This should be set to greater than 1 if the - size of the model is larger than the memory available on a single GPU on the - instance. Defaults to None. If not set, no tensor parallel sharding is done. - **kwargs: Keyword arguments passed to the superclasses - :class:`~sagemaker.djl_inference.DJLModel`, - :class:`~sagemaker.model.FrameworkModel`, and - :class:`~sagemaker.model.Model` - - .. 
tip::
-
-        You can find additional parameters for initializing this class at
-        :class:`~sagemaker.djl_inference.DJLModel`,
-        :class:`~sagemaker.model.FrameworkModel`, and
-        :class:`~sagemaker.model.Model`.
-        """
-        if "hf_hub_token" in kwargs:
-            kwargs.pop("hf_hub_token")
-        super(FasterTransformerModel, self).__init__(
-            model_id,
-            role,
-            **kwargs,
-        )
-        if self.number_of_partitions and tensor_parallel_degree:
-            logger.warning(
-                "Both number_of_partitions and tensor_parallel_degree have been set for "
-                "FasterTransformerModel."
-                "These mean the same thing for FasterTransformerModel. Please only set "
-                "tensor_parallel_degree."
-                "number_of_partitions will be ignored"
-            )
-        self.number_of_partitions = tensor_parallel_degree or self.number_of_partitions
diff --git a/src/sagemaker/enums.py b/src/sagemaker/enums.py
index 5b4d0d6790..f8c618620b 100644
--- a/src/sagemaker/enums.py
+++ b/src/sagemaker/enums.py
@@ -28,3 +28,24 @@ class EndpointType(Enum):
     INFERENCE_COMPONENT_BASED = (
         "InferenceComponentBased"  # Amazon SageMaker Inference Component Based Endpoint
     )
+
+
+class RoutingStrategy(Enum):
+    """Strategy for routing HTTPS traffic."""
+
+    RANDOM = "RANDOM"
+    """The endpoint routes each request to a randomly chosen instance.
+    """
+    LEAST_OUTSTANDING_REQUESTS = "LEAST_OUTSTANDING_REQUESTS"
+    """The endpoint routes requests to the specific instances that have
+    more capacity to process them.
+    """
+
+
+class Tag(str, Enum):
+    """Enum class for tag keys to apply to models."""
+
+    OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name"
+    SPECULATIVE_DRAFT_MODEL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider"
+    FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path"
+    FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name"
diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py
index b67066fcde..a22890e873 100644
--- a/src/sagemaker/environment_variables.py
+++ b/src/sagemaker/environment_variables.py
@@ -20,7 +20,7 @@
 from sagemaker.jumpstart import utils as jumpstart_utils
 from sagemaker.jumpstart import artifacts
 from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
-from sagemaker.jumpstart.enums import JumpStartScriptScope
+from sagemaker.jumpstart.enums import JumpStartModelType, JumpStartScriptScope
 from sagemaker.session import Session
 
 logger = logging.getLogger(__name__)
@@ -30,12 +30,15 @@
 def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     include_aws_sdk_env_vars: bool = True,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     instance_type: Optional[str] = None,
     script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
+    config_name: Optional[str] = None,
+    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
 ) -> Dict[str, str]:
     """Retrieves the default container environment variables for the model matching the arguments.
@@ -46,6 +49,8 @@
             retrieve the default environment variables. (Default: None).
         model_version (str): Optional. The version of the model for which to retrieve the
             default environment variables. (Default: None).
+        hub_arn (str): The ARN of the SageMaker Hub from which to
+            retrieve model details. (Default: None).
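The two enums added to ``src/sagemaker/enums.py`` above are plain value holders. A small sketch; the ``routing_config`` shape is an assumption about how ``deploy()`` consumes a routing strategy and is not shown in this diff:

```python
from sagemaker.enums import RoutingStrategy, Tag

routing_config = {"RoutingStrategy": RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.value}

# Tag subclasses str, so members compare and serialize as their key strings.
assert Tag.OPTIMIZATION_JOB_NAME == "sagemaker-sdk:optimization-job-name"
tags = [{"Key": Tag.FINE_TUNING_JOB_NAME, "Value": "my-tuning-job"}]
```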
tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with known
@@ -65,6 +70,9 @@
             variables specific for the instance type.
         script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
             variables.
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
+        model_type (JumpStartModelType): The type of the model, can be open weights model
+            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
 
     Returns:
         dict: The variables to use for the model.
@@ -78,13 +86,16 @@
         )
 
     return artifacts._retrieve_default_environment_variables(
-        model_id,
-        model_version,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
-        include_aws_sdk_env_vars,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
+        include_aws_sdk_env_vars=include_aws_sdk_env_vars,
         sagemaker_session=sagemaker_session,
         instance_type=instance_type,
         script=script,
+        config_name=config_name,
+        model_type=model_type,
     )
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 58a5fabc2f..8cd6410ea0 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -68,6 +68,7 @@
 from sagemaker.interactive_apps import SupportedInteractiveAppTypes
 from sagemaker.interactive_apps.tensorboard import TensorBoardApp
 from sagemaker.instance_group import InstanceGroup
+from sagemaker.model_card.model_card import ModelCard, TrainingDetails
 from sagemaker.utils import instance_supports_kms
 from sagemaker.job import _Job
 from sagemaker.jumpstart.utils import (
@@ -105,6 +106,8 @@
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.parameters import ParameterString
 from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
+from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
+from sagemaker.telemetry.constants import Feature
 
 logger = logging.getLogger(__name__)
 
@@ -182,6 +185,8 @@ def __init__(
         disable_output_compression: bool = False,
         enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
         enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
+        training_plan: Optional[Union[str, PipelineVariable]] = None,
+        instance_placement_config: Optional[Dict] = None,
         **kwargs,
     ):
         """Initialize an ``EstimatorBase`` instance.
@@ -274,7 +279,10 @@ def __init__(
                 AWS services needed. If not specified, the estimator creates one
                 using the default AWS configuration chain.
             tags (Optional[Tags]):
-                Tags for labeling a training job. For more, see
+                Tags for labeling a training job. These won't be propagated to Models or
+                Endpoints during :meth:`~sagemaker.estimator.EstimatorBase.deploy`. The
+                :meth:`~sagemaker.estimator.EstimatorBase.deploy` takes in a separate
+                tags parameter. For more on tags, see
                 https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
             subnets (list[str] or list[PipelineVariable]): List of subnet ids.
                 If not specified training job will be created without VPC config.
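With the signature extended and everything forwarded by keyword, calling ``retrieve_default`` stays simple. The model id below is a placeholder JumpStart id:

```python
from sagemaker.environment_variables import retrieve_default

env_vars = retrieve_default(
    region="us-west-2",
    model_id="huggingface-llm-falcon-7b-bf16",  # placeholder JumpStart model id
    model_version="*",
    instance_type="ml.g5.2xlarge",
    config_name=None,  # new: optionally pin a JumpStart model config
)
print(env_vars)
```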
@@ -380,8 +388,8 @@ def __init__( source_dir (str or PipelineVariable): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. The structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. The structure within this directory + is preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. With the following GitHub repo directory structure: @@ -449,6 +457,9 @@ def __init__( A dictionary containing the hyperparameters to initialize this estimator with. (Default: None). + If a source directory is specified, the set_hyperparameters method escapes + the dict argument as JSON, and updates the private hyperparameter attribute. + .. caution:: You must not include any security-sensitive information, such as account access IDs, secrets, and tokens, in the dictionary for configuring @@ -548,6 +559,23 @@ def __init__( Specifies whether RemoteDebug is enabled for the training job. enable_session_tag_chaining (bool or PipelineVariable): Optional. Specifies whether SessionTagChaining is enabled for the training job. + training_plan (str or PipelineVariable): Optional. + Specifies which training plan arn to use for the training job + instance_placement_config (dict): Optional. + Specifies UltraServer placement configuration for the training job + + .. code:: python + + instance_placement_config={ + "EnableMultipleJobs": True, + "PlacementSpecifications":[ + { + "UltraServerId": "ultraserver-1", + "InstanceCount": "2" + } + ] + } + """ instance_count = renamed_kwargs( "train_instance_count", "instance_count", instance_count, kwargs @@ -586,25 +614,36 @@ def __init__( self.dependencies = dependencies or [] self.uploaded_code: Optional[UploadedCode] = None - # Check that the user properly sets both subnet and secutiry_groupe_ids + # Check that the user properly sets both subnet and security_group_ids if ( subnets is not None and security_group_ids is None or security_group_ids is not None and subnets is None ): + troubleshooting = ( + "Refer to this documentation on using custom VPC: " + "https://sagemaker.readthedocs.io/en/v2.24.0/overview.html" + "#secure-training-and-inference-with-vpc" + ) + logger.error("Check troubleshooting guide for common errors: %s", troubleshooting) + raise RuntimeError( "When setting up custom VPC, both subnets and security_group_ids must be set" ) if self.instance_type in ("local", "local_gpu"): if self.instance_type == "local_gpu" and self.instance_count > 1: - raise RuntimeError("Distributed Training in Local GPU is not supported") + raise RuntimeError( + "Distributed Training in Local GPU is not supported." + " Set instance_count to 1." + ) self.sagemaker_session = sagemaker_session or LocalSession() if not isinstance(self.sagemaker_session, sagemaker.local.LocalSession): raise RuntimeError( "instance_type local or local_gpu is only supported with an" - "instance of LocalSession" + "instance of LocalSession. 
More details on local mode: " + "https://sagemaker.readthedocs.io/en/stable/overview.html#local-mode" ) else: self.sagemaker_session = sagemaker_session or Session() @@ -627,7 +666,11 @@ def __init__( and not is_pipeline_variable(output_path) and output_path.startswith("file://") ): - raise RuntimeError("file:// output paths are only supported in Local Mode") + raise RuntimeError( + "The 'file://' output paths are only supported when using Local Mode. " + "To resolve this issue, ensure you're running in Local Mode with a LocalSession, " + "or use an 's3://' output path for jobs running on SageMaker instances." + ) self.output_path = output_path self.latest_training_job = None self.jobs = [] @@ -642,7 +685,12 @@ def __init__( # Now we marked that as Optional because we can fetch it from SageMakerConfig # Because of marking that parameter as optional, we should validate if it is None, even # after fetching the config. - raise ValueError("An AWS IAM role is required to create an estimator.") + raise ValueError( + "An AWS IAM role is required to create an estimator. " + "Please provide a valid `role` argument with the ARN of an IAM role" + " that has the necessary SageMaker permissions." + ) + self.output_kms_key = resolve_value_from_config( output_kms_key, TRAINING_JOB_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session ) @@ -736,8 +784,7 @@ def __init__( self.tensorboard_output_config = tensorboard_output_config - self.debugger_rule_configs = None - self.collection_configs = None + self.debugger_rule_configs, self.collection_configs = None, None self.enable_sagemaker_metrics = enable_sagemaker_metrics @@ -748,6 +795,7 @@ def __init__( sagemaker_session=self.sagemaker_session, ) + self.profiler_rule_configs, self.profiler_rules = None, None self.profiler_config = profiler_config self.disable_profiler = resolve_value_from_config( direct_input=disable_profiler, @@ -770,8 +818,6 @@ def __init__( ) or _instance_type_supports_profiler(self.instance_type): self.disable_profiler = True - self.profiler_rule_configs = None - self.profiler_rules = None self.debugger_rules = None self.disable_output_compression = disable_output_compression validate_source_code_input_against_pipeline_variables( @@ -781,6 +827,10 @@ def __init__( enable_network_isolation=self._enable_network_isolation, ) + self.training_plan = training_plan + + self.instance_placement_config = instance_placement_config + # Internal flag self._is_output_path_set_from_default_bucket_and_prefix = False @@ -873,6 +923,30 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A } return hyperparameters + @staticmethod + def _nova_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, Any]: + """Applies JSON encoding for Nova job hyperparameters, preserving string values. + + For Nova jobs, string values should not be JSON-encoded. + + Args: + hyperparameters (dict): Dictionary of hyperparameters. + + Returns: + dict: Dictionary with encoded hyperparameters. + """ + current_hyperparameters = hyperparameters + if current_hyperparameters is not None: + hyperparameters = {} + for k, v in current_hyperparameters.items(): + if is_pipeline_variable(v): + hyperparameters[str(k)] = v.to_string() + elif isinstance(v, str): + hyperparameters[str(k)] = v + else: + hyperparameters[str(k)] = json.dumps(v) + return hyperparameters + def _prepare_for_training(self, job_name=None): """Set any values in the estimator that need to be set before training. 
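The only difference between ``_nova_encode_hyperparameters`` and the JSON encoder above it is that string values pass through without JSON quoting. A standalone illustration with placeholder values:

```python
import json

hps = {"epochs": 5, "recipe": "s3://my-bucket/recipe.yaml"}

json_encoded = {k: json.dumps(v) for k, v in hps.items()}
nova_encoded = {k: v if isinstance(v, str) else json.dumps(v) for k, v in hps.items()}

assert json_encoded["recipe"] == '"s3://my-bucket/recipe.yaml"'  # extra quotes
assert nova_encoded["recipe"] == "s3://my-bucket/recipe.yaml"    # preserved as-is
assert json_encoded["epochs"] == nova_encoded["epochs"] == "5"   # non-strings unchanged
```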
@@ -906,7 +980,11 @@ def _prepare_for_training(self, job_name=None): self.source_dir = updated_paths["source_dir"] self.dependencies = updated_paths["dependencies"] - if self.source_dir or self.entry_point or self.dependencies: + if ( + self.source_dir + or self.entry_point + or (self.dependencies and len(self.dependencies) > 0) + ): # validate source dir will raise a ValueError if there is something wrong with # the source directory. We are intentionally not handling it because this is a # critical error. @@ -1272,6 +1350,7 @@ def latest_job_profiler_artifacts_path(self): ) return None + @_telemetry_emitter(feature=Feature.ESTIMATOR, func_name="estimator.fit") @runnable_by_pipeline def fit( self, @@ -1342,8 +1421,20 @@ def fit( experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config) self.jobs.append(self.latest_training_job) + forward_to_mlflow_tracking_server = False + if os.environ.get("MLFLOW_TRACKING_URI") and self.enable_network_isolation(): + wait = True + forward_to_mlflow_tracking_server = True if wait: self.latest_training_job.wait(logs=logs) + try: + if forward_to_mlflow_tracking_server: + from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow + + log_sagemaker_job_to_mlflow(self.latest_training_job.name) + except ImportError: + if forward_to_mlflow_tracking_server: + raise ValueError("Unable to import mlflow, check if sagemaker-mlflow is installed") def _compilation_job_name(self): """Placeholder docstring""" @@ -1724,6 +1815,8 @@ def register( data_input_configuration=None, skip_model_validation=None, source_uri=None, + model_life_cycle=None, + model_card=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1772,6 +1865,9 @@ def register( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). + model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative and + quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs.
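As a usage sketch of the new MLflow hook in ``fit`` above (every name, ARN, and URI below is a hypothetical placeholder, and the optional ``sagemaker-mlflow`` package must be installed): when ``MLFLOW_TRACKING_URI`` is set and the job runs with network isolation, ``fit`` forces a blocking wait and then forwards the finished job's metrics to the tracking server.

.. code:: python

    import os

    from sagemaker.estimator import Estimator

    os.environ["MLFLOW_TRACKING_URI"] = "https://my-tracking-server.example.com"  # hypothetical

    estimator = Estimator(
        image_uri="<training-image-uri>",  # hypothetical placeholder
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # hypothetical
        instance_count=1,
        instance_type="ml.m5.xlarge",
        enable_network_isolation=True,
    )

    # Both conditions hold, so fit() overrides wait to True and, once the job
    # completes, calls log_sagemaker_job_to_mlflow(<job name>) on the result.
    estimator.fit("s3://my-bucket/train/")  # hypothetical input location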
@@ -1791,8 +1887,17 @@ def register( else: if "model_kms_key" not in kwargs: kwargs["model_kms_key"] = self.output_kms_key - model = self.create_model(image_uri=image_uri, **kwargs) + model = self.create_model(image_uri=image_uri, name=model_name, **kwargs) model.name = model_name + if self.model_data is not None and model_card is None: + training_details = TrainingDetails.from_model_s3_artifacts( + model_artifacts=[self.model_data], sagemaker_session=self.sagemaker_session + ) + model_card = ModelCard( + name="estimator_card", + training_details=training_details, + sagemaker_session=self.sagemaker_session, + ) return model.register( content_types, response_types, @@ -1817,6 +1922,8 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + model_life_cycle=model_life_cycle, ) @property @@ -1838,6 +1945,8 @@ def model_data(self): if compression_type not in {"GZIP", "NONE"}: raise ValueError( f'Unrecognized training job output data compression type "{compression_type}"' + '. Please specify either "GZIP" or "NONE" as valid options for ' + "the compression type." ) # model data is in uncompressed form NOTE SageMaker Hosting mandates presence of # trailing forward slash in S3 model data URI, so append one if necessary. @@ -1903,6 +2012,14 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na "KeepAlivePeriodInSeconds" ] + if "TrainingPlanArn" in job_details["ResourceConfig"]: + init_params["training_plan"] = job_details["ResourceConfig"]["TrainingPlanArn"] + + if "InstancePlacementConfig" in job_details["ResourceConfig"]: + init_params["instance_placement_config"] = job_details["ResourceConfig"][ + "InstancePlacementConfig" + ] + has_hps = "HyperParameters" in job_details init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {} @@ -2452,6 +2569,11 @@ def start_new(cls, estimator, inputs, experiment_config): return cls(estimator.sagemaker_session, estimator._current_job_name) + @classmethod + def get_train_args(cls, estimator, inputs, experiment_config): + """A public function which is the same as the _get_train_args function.""" + return cls._get_train_args(estimator, inputs, experiment_config) + @classmethod def _get_train_args(cls, estimator, inputs, experiment_config): """Constructs a dict of arguments for an Amazon SageMaker training job from the estimator. @@ -2487,7 +2609,6 @@ def _get_train_args(cls, estimator, inputs, experiment_config): raise ValueError( "File URIs are supported in local mode only. Please use a S3 URI instead." ) - config = _Job._load_config(inputs, estimator) current_hyperparameters = estimator.hyperparameters() @@ -2783,6 +2904,8 @@ def __init__( enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, + instance_placement_config: Optional[Dict] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -3148,6 +3271,22 @@ def __init__( Specifies whether RemoteDebug is enabled for the training job enable_session_tag_chaining (bool or PipelineVariable): Optional. Specifies whether SessionTagChaining is enabled for the training job + training_plan (str or PipelineVariable): Optional. + Specifies which training plan arn to use for the training job + instance_placement_config (dict): Optional.
+ Specifies UltraServer placement configuration for the training job + + .. code:: python + + instance_placement_config={ + "EnableMultipleJobs": True, + "PlacementSpecifications":[ + { + "UltraServerId": "ultraserver-1", + "InstanceCount": "2" + } + ] + } """ self.image_uri = image_uri self._hyperparameters = hyperparameters.copy() if hyperparameters else {} @@ -3201,6 +3340,8 @@ def __init__( disable_output_compression=disable_output_compression, enable_remote_debug=enable_remote_debug, enable_session_tag_chaining=enable_session_tag_chaining, + training_plan=training_plan, + instance_placement_config=instance_placement_config, **kwargs, ) @@ -3354,8 +3495,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. @@ -3510,7 +3651,11 @@ def __init__( git_config=git_config, enable_network_isolation=enable_network_isolation, ) - if not is_pipeline_variable(entry_point) and entry_point.startswith("s3://"): + if ( + not is_pipeline_variable(entry_point) + and entry_point is not None + and entry_point.startswith("s3://") + ): raise ValueError( "Invalid entry point script: {}. Must be a path to a local file.".format( entry_point @@ -3530,6 +3675,7 @@ def __init__( self.checkpoint_s3_uri = checkpoint_s3_uri self.checkpoint_local_path = checkpoint_local_path self.enable_sagemaker_metrics = enable_sagemaker_metrics + self.is_nova_job = kwargs.get("is_nova_job", False) def _prepare_for_training(self, job_name=None): """Set hyperparameters needed for training. This method will also validate ``source_dir``. @@ -3644,7 +3790,10 @@ def _model_entry_point(self): def set_hyperparameters(self, **kwargs): """Escapes the dict argument as JSON, updates the private hyperparameter attribute.""" - self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(kwargs)) + if self.is_nova_job: + self._hyperparameters.update(EstimatorBase._nova_encode_hyperparameters(kwargs)) + else: + self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(kwargs)) def hyperparameters(self): """Returns the hyperparameters as a dictionary to use for training. @@ -3655,7 +3804,10 @@ def hyperparameters(self): Returns: dict[str, str]: The hyperparameters. 
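For example (an illustrative sketch with hypothetical values), a Nova job returns string values unchanged, while a standard job JSON-encodes them:

.. code:: python

    estimator.set_hyperparameters(epochs=3, base_model="nova-micro")
    estimator.hyperparameters()
    # Nova job:     {'epochs': '3', 'base_model': 'nova-micro'}
    # standard job: {'epochs': '3', 'base_model': '"nova-micro"'}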
""" - return EstimatorBase._json_encode_hyperparameters(self._hyperparameters) + if self.is_nova_job: + return EstimatorBase._nova_encode_hyperparameters(self._hyperparameters) + else: + return EstimatorBase._json_encode_hyperparameters(self._hyperparameters) @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): diff --git a/src/sagemaker/experiments/_metrics.py b/src/sagemaker/experiments/_metrics.py index 31dd679cc8..026e73e8a6 100644 --- a/src/sagemaker/experiments/_metrics.py +++ b/src/sagemaker/experiments/_metrics.py @@ -197,8 +197,8 @@ def _send_metrics(self, metrics): response = self._metrics_client.batch_put_metrics(**request) errors = response["Errors"] if "Errors" in response else None if errors: - message = errors[0]["Message"] - raise Exception(f'{len(errors)} errors with message "{message}"') + error_code = errors[0]["Code"] + raise Exception(f'{len(errors)} errors with error code "{error_code}"') def _construct_batch_put_metrics_request(self, batch): """Creates dictionary object used as request to metrics service.""" diff --git a/src/sagemaker/feature_store/dataset_builder.py b/src/sagemaker/feature_store/dataset_builder.py index 289fa1ee0c..fc9f9372b1 100644 --- a/src/sagemaker/feature_store/dataset_builder.py +++ b/src/sagemaker/feature_store/dataset_builder.py @@ -929,7 +929,7 @@ def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str: selected_features += ", " selected_features += ", ".join( [ - f'fg_{i}."{feature_name}" as "{feature_name}.{(i+1)}"' + f'fg_{i}."{feature_name}" as "{feature_name}.{(i + 1)}"' for feature_name in feature_group.projected_feature_names ] ) diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 39915b60dc..4eb8d82b0c 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -631,7 +631,7 @@ def __str__(self) -> str: class FeatureGroup: """FeatureGroup definition. - This class instantiates a FeatureGroup object that comprises of a name for the FeatureGroup, + This class instantiates a FeatureGroup object that comprises a name for the FeatureGroup, session instance, and a list of feature definition objects i.e., FeatureDefinition. Attributes: diff --git a/src/sagemaker/feature_store/feature_processor/_input_offset_parser.py b/src/sagemaker/feature_store/feature_processor/_input_offset_parser.py index 17e4139bc6..2b66553ab3 100644 --- a/src/sagemaker/feature_store/feature_processor/_input_offset_parser.py +++ b/src/sagemaker/feature_store/feature_processor/_input_offset_parser.py @@ -72,14 +72,16 @@ def get_offset_datetime(self, offset: Optional[str]) -> datetime: return self.now + offset_td - def get_offset_date_year_month_day_hour(self, offset: Optional[str]) -> Tuple[str]: + def get_offset_date_year_month_day_hour( + self, offset: Optional[str] + ) -> Tuple[str, str, str, str]: """Get the year, month, day and hour based on offset diff. Args: offset (Optional[str]): Offset that is used for target date calcluation. Returns: - Tuple[str]: A tuple that consists of extracted year, month, day, hour from offset date. + Tuple[str, str, str, str]: A tuple that consists of extracted year, month, day, hour from offset date. 
""" if offset is None: return (None, None, None, None) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index cf9291a139..4a00b2dbc1 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -10,30 +10,29 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Utility methods used by framework classes""" +"""Utility methods used by framework classes.""" from __future__ import absolute_import import json import logging import os import re -import time import shutil import tempfile +import time from collections import namedtuple -from typing import List, Optional, Union, Dict +from typing import Dict, List, Optional, Union + from packaging import version import sagemaker.image_uris +import sagemaker.utils +from sagemaker.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning from sagemaker.instance_group import InstanceGroup from sagemaker.s3_utils import s3_path_join from sagemaker.session_settings import SessionSettings -import sagemaker.utils from sagemaker.workflow import is_pipeline_variable - -from sagemaker.deprecations import renamed_warning, renamed_kwargs from sagemaker.workflow.entities import PipelineVariable -from sagemaker.deprecations import deprecation_warn_base logger = logging.getLogger(__name__) @@ -41,6 +40,7 @@ UploadedCode = namedtuple("UploadedCode", ["s3_prefix", "script_name"]) """sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name. + This is for the source code used for the entry point with an ``Estimator``. It can be instantiated with positional or keyword arguments. """ @@ -142,26 +142,9 @@ "2.1.0", "2.1.2", "2.2.0", - "2.3.0", ], } -PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [ - "1.10", - "1.10.0", - "1.10.2", - "1.11", - "1.11.0", - "1.12", - "1.12.0", - "1.12.1", - "1.13.1", - "2.0.0", - "2.0.1", - "2.1.0", - "2.2.0", -] - TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [ "1.13.1", "2.0.0", @@ -170,6 +153,9 @@ "2.1.2", "2.2.0", "2.3.0", + "2.3.1", + "2.4.1", + "2.5.1", ] TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"] @@ -226,7 +212,7 @@ def validate_source_code_input_against_pipeline_variables( git_config: Optional[Dict[str, str]] = None, enable_network_isolation: Union[bool, PipelineVariable] = False, ): - """Validate source code input against pipeline variables + """Validate source code input against pipeline variables. Args: entry_point (str or PipelineVariable): The path to the local Python source file that @@ -267,7 +253,7 @@ def validate_source_code_input_against_pipeline_variables( logger.warning( "The source_dir is a pipeline variable: %s. During pipeline execution, " "the interpreted value of source_dir has to be an S3 URI and " - "must point to a tar.gz file", + "must point to a file with name ``sourcedir.tar.gz``", type(source_dir), ) @@ -496,7 +482,7 @@ def tar_and_upload_dir( def _list_files_to_compress(script, directory): - """Placeholder docstring""" + """Placeholder docstring.""" if directory is None: return [script] @@ -600,7 +586,6 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image): The location returned is a potential concatenation of 2 parts 1. code_location_key_prefix if it exists 2. 
model_name or a name derived from the image - Args: code_location_key_prefix (str): the s3 key prefix from code_location model_name (str): the name of the model @@ -635,8 +620,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution "enabled": True } } - - """ if training_instance_type == "local" or distribution is None: return @@ -661,7 +644,7 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution def profiler_config_deprecation_warning( profiler_config, image_uri, framework_name, framework_version ): - """Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0""" + """Deprecation message if framework profiling is specified TF >= 2.12 and PT >= 2.0.""" if profiler_config is None or profiler_config.framework_profile_params is None: return @@ -707,6 +690,7 @@ def validate_smdistributed( framework_name (str): A string representing the name of framework selected. framework_version (str): A string representing the framework version selected. py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` distribution (dict): A dictionary with information to enable distributed training. (Defaults to None if distributed training is not enabled.) For example: @@ -778,7 +762,8 @@ def _validate_smdataparallel_args( instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge` framework_name (str): A string representing the name of framework selected. Ex: `tensorflow` framework_version (str): A string representing the framework version selected. Ex: `2.3.1` - py_version (str): A string representing the python version selected. Ex: `py3` + py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` distribution (dict): A dictionary with information to enable distributed training. (Defaults to None if distributed training is not enabled.) Ex: @@ -795,7 +780,6 @@ def _validate_smdataparallel_args( Raises: ValueError: if - (`instance_type` is not in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES or `py_version` is not python3 or `framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSION """ @@ -806,17 +790,10 @@ def _validate_smdataparallel_args( if not smdataparallel_enabled: return - is_instance_type_supported = instance_type in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES - err_msg = "" - if not is_instance_type_supported: - # instance_type is required - err_msg += ( - f"Provided instance_type {instance_type} is not supported by smdataparallel.\n" - "Please specify one of the supported instance types:" - f"{SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES}\n" - ) + if not instance_type: + err_msg += "Please specify an instance_type for smdataparallel.\n" if not image_uri: # ignore framework_version & py_version if image_uri is set @@ -870,6 +847,7 @@ def validate_distribution( framework_name (str): A string representing the name of framework selected. framework_version (str): A string representing the framework version selected. py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` image_uri (str): A string representing a Docker image URI. 
kwargs(dict): Additional kwargs passed to this function @@ -928,13 +906,6 @@ def validate_distribution( ) if framework_name and framework_name == "pytorch": # We need to validate only for PyTorch framework - validate_pytorch_distribution( - distribution=validated_distribution, - framework_name=framework_name, - framework_version=framework_version, - py_version=py_version, - image_uri=image_uri, - ) validate_torch_distributed_distribution( instance_type=instance_type, distribution=validated_distribution, @@ -968,13 +939,6 @@ def validate_distribution( ) if framework_name and framework_name == "pytorch": # We need to validate only for PyTorch framework - validate_pytorch_distribution( - distribution=validated_distribution, - framework_name=framework_name, - framework_version=framework_version, - py_version=py_version, - image_uri=image_uri, - ) validate_torch_distributed_distribution( instance_type=instance_type, distribution=validated_distribution, @@ -990,7 +954,7 @@ def validate_distribution( def validate_distribution_for_instance_type(instance_type, distribution): - """Check if the provided distribution strategy is supported for the instance_type + """Check if the provided distribution strategy is supported for the instance_type. Args: instance_type (str): A string representing the type of training instance selected. @@ -1023,63 +987,6 @@ def validate_distribution_for_instance_type(instance_type, distribution): raise ValueError(err_msg) -def validate_pytorch_distribution( - distribution, framework_name, framework_version, py_version, image_uri -): - """Check if pytorch distribution strategy is correctly invoked by the user. - - Args: - distribution (dict): A dictionary with information to enable distributed training. - (Defaults to None if distributed training is not enabled.) For example: - - .. code:: python - - { - "pytorchddp": { - "enabled": True - } - } - framework_name (str): A string representing the name of framework selected. - framework_version (str): A string representing the framework version selected. - py_version (str): A string representing the python version selected. - image_uri (str): A string representing a Docker image URI. 
- - Raises: - ValueError: if - `py_version` is not python3 or - `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS - """ - if framework_name and framework_name != "pytorch": - # We need to validate only for PyTorch framework - return - - pytorch_ddp_enabled = False - if "pytorchddp" in distribution: - pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) - if not pytorch_ddp_enabled: - # Distribution strategy other than pytorchddp is selected - return - - err_msg = "" - if not image_uri: - # ignore framework_version and py_version if image_uri is set - # in case image_uri is not set, then both are mandatory - if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS: - err_msg += ( - f"Provided framework_version {framework_version} is not supported by" - " pytorchddp.\n" - "Please specify one of the supported framework versions:" - f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n" - ) - if "py3" not in py_version: - err_msg += ( - f"Provided py_version {py_version} is not supported by pytorchddp.\n" - "Please specify py_version>=py3" - ) - if err_msg: - raise ValueError(err_msg) - - def validate_torch_distributed_distribution( instance_type, distribution, @@ -1104,6 +1011,7 @@ def validate_torch_distributed_distribution( } framework_version (str): A string representing the framework version selected. py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` image_uri (str): A string representing a Docker image URI. entry_point (str or PipelineVariable): The absolute or relative path to the local Python source file that should be executed as the entry point to @@ -1155,7 +1063,7 @@ def validate_torch_distributed_distribution( ) # Check entry point type - if not entry_point.endswith(".py"): + if entry_point is not None and not entry_point.endswith(".py"): err_msg += ( "Unsupported entry point type for the distribution torch_distributed.\n" "Only python programs (*.py) are supported." @@ -1166,7 +1074,7 @@ def validate_torch_distributed_distribution( def _is_gpu_instance(instance_type): - """Returns bool indicating whether instance_type supports GPU + """Returns bool indicating whether instance_type supports GPU. Args: instance_type (str): Name of the instance_type to check against. @@ -1185,7 +1093,7 @@ def _is_gpu_instance(instance_type): def _is_trainium_instance(instance_type): - """Returns bool indicating whether instance_type is a Trainium instance + """Returns bool indicating whether instance_type is a Trainium instance. Args: instance_type (str): Name of the instance_type to check against. @@ -1201,7 +1109,7 @@ def _is_trainium_instance(instance_type): def python_deprecation_warning(framework, latest_supported_version): - """Placeholder docstring""" + """Placeholder docstring.""" return PYTHON_2_DEPRECATION_WARNING.format( framework=framework, latest_supported_version=latest_supported_version ) @@ -1215,7 +1123,6 @@ def _region_supports_debugger(region_name): Returns: bool: Whether or not the region supports Amazon SageMaker Debugger. - """ return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS @@ -1228,7 +1135,6 @@ def _region_supports_profiler(region_name): Returns: bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature. - """ return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS @@ -1256,7 +1162,8 @@ def validate_version_or_image_args(framework_version, py_version, image_uri): Args: framework_version (str): The version of the framework. 
- py_version (str): The version of Python. + py_version (str): A string representing the python version selected. + Ex: `py38, py39, py310, py311` image_uri (str): The URI of the image. Raises: @@ -1288,9 +1195,8 @@ def create_image_uri( instance_type (str): SageMaker instance type. Used to determine device type (cpu/gpu/family-specific optimized). framework_version (str): The version of the framework. - py_version (str): Optional. Python version. If specified, should be one - of 'py2' or 'py3'. If not specified, image uri will not include a - python component. + py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`. + If not specified, image uri will not include a python component. account (str): AWS account that contains the image. (default: '520713654638') accelerator_type (str): SageMaker Elastic Inference accelerator type. diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py index 49d151a00b..25e745446a 100644 --- a/src/sagemaker/git_utils.py +++ b/src/sagemaker/git_utils.py @@ -14,14 +14,78 @@ from __future__ import absolute_import import os -from pathlib import Path +import re import subprocess import tempfile import warnings +from pathlib import Path +from urllib.parse import urlparse + import six from six.moves import urllib +def _sanitize_git_url(repo_url): + """Sanitize Git repository URL to prevent URL injection attacks. + + Args: + repo_url (str): The Git repository URL to sanitize + + Returns: + str: The sanitized URL + + Raises: + ValueError: If the URL contains suspicious patterns that could indicate injection + """ + at_count = repo_url.count("@") + + if repo_url.startswith("git@"): + # git@ format requires exactly one @ + if at_count != 1: + raise ValueError("Invalid SSH URL format: git@ URLs must have exactly one @ symbol") + elif repo_url.startswith("ssh://"): + # ssh:// format can have 0 or 1 @ symbols + if at_count > 1: + raise ValueError("Invalid SSH URL format: multiple @ symbols detected") + elif repo_url.startswith("https://") or repo_url.startswith("http://"): + # HTTPS format allows 0 or 1 @ symbols + if at_count > 1: + raise ValueError("Invalid HTTPS URL format: multiple @ symbols detected") + + # Check for invalid characters in the URL before parsing + # These characters should not appear in legitimate URLs + invalid_chars = ["<", ">", "[", "]", "{", "}", "\\", "^", "`", "|"] + for char in invalid_chars: + if char in repo_url: + raise ValueError("Invalid characters in hostname") + + try: + parsed = urlparse(repo_url) + + # Check for suspicious characters in hostname that could indicate injection + if parsed.hostname: + # Check for URL-encoded characters that might be used for obfuscation + suspicious_patterns = ["%25", "%40", "%2F", "%3A"] # encoded %, @, /, : + for pattern in suspicious_patterns: + if pattern in parsed.hostname.lower(): + raise ValueError(f"Suspicious URL encoding detected in hostname: {pattern}") + + # Validate that the hostname looks legitimate + if not re.match(r"^[a-zA-Z0-9.-]+$", parsed.hostname): + raise ValueError("Invalid characters in hostname") + + except Exception as e: + if isinstance(e, ValueError): + raise + raise ValueError(f"Failed to parse URL: {str(e)}") + else: + raise ValueError( + "Unsupported URL scheme: only https://, http://, git@, and ssh:// are allowed" + ) + + return repo_url + + def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None): """Git clone repo containing the training code and serving code. 
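A quick sketch of how the new sanitizer behaves (illustrative URLs; the rejected call exercises the multiple-@ path shown above):

.. code:: python

    from sagemaker.git_utils import _sanitize_git_url

    # Well-formed URLs are returned unchanged.
    _sanitize_git_url("https://github.com/aws/sagemaker-python-sdk.git")
    _sanitize_git_url("git@github.com:aws/sagemaker-python-sdk.git")

    # An injection attempt smuggling a second '@' is rejected.
    try:
        _sanitize_git_url("https://user@evil.example.com@github.com/aws/repo.git")
    except ValueError as err:
        print(err)  # Invalid HTTPS URL format: multiple @ symbols detected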
@@ -87,6 +151,10 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None): if entry_point is None: raise ValueError("Please provide an entry point.") _validate_git_config(git_config) + + # SECURITY: Sanitize the repository URL to prevent injection attacks + git_config["repo"] = _sanitize_git_url(git_config["repo"]) + dest_dir = tempfile.mkdtemp() _generate_and_run_clone_command(git_config, dest_dir) diff --git a/src/sagemaker/huggingface/estimator.py b/src/sagemaker/huggingface/estimator.py index 86df43d4e9..70cc17b209 100644 --- a/src/sagemaker/huggingface/estimator.py +++ b/src/sagemaker/huggingface/estimator.py @@ -15,17 +15,13 @@ import logging import re -from typing import Optional, Union, Dict +from typing import Dict, Optional, Union -from sagemaker.estimator import Framework, EstimatorBase -from sagemaker.fw_utils import ( - framework_name_from_image, - validate_distribution, -) +from sagemaker.estimator import EstimatorBase, Framework +from sagemaker.fw_utils import framework_name_from_image, validate_distribution from sagemaker.huggingface.model import HuggingFaceModel -from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT - from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig +from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -66,7 +62,7 @@ def __init__( Args: py_version (str): Python version you want to use for executing your model training code. Defaults to ``None``. Required unless ``image_uri`` is provided. If - using PyTorch, the current supported version is ``py36``. If using TensorFlow, + using PyTorch, the current supported version is ``py39``. If using TensorFlow, the current supported version is ``py37``. entry_point (str or PipelineVariable): Path (absolute or relative) to the Python source file which should be executed as the entry point to training. @@ -84,8 +80,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory are + preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). 
The hyperparameters are made accessible as a dict[str, str] to the training code on diff --git a/src/sagemaker/huggingface/llm_utils.py b/src/sagemaker/huggingface/llm_utils.py index de5e624dbc..c7a1316760 100644 --- a/src/sagemaker/huggingface/llm_utils.py +++ b/src/sagemaker/huggingface/llm_utils.py @@ -13,7 +13,9 @@ """Functions for generating ECR image URIs for pre-built SageMaker Docker images.""" from __future__ import absolute_import +import os from typing import Optional +import importlib.util import urllib.request from urllib.error import HTTPError, URLError @@ -65,6 +67,20 @@ def get_huggingface_llm_image_uri( image_scope="inference", inference_tool="neuronx", ) + if backend == "huggingface-tei": + return image_uris.retrieve( + "huggingface-tei", + region=region, + version=version, + image_scope="inference", + ) + if backend == "huggingface-tei-cpu": + return image_uris.retrieve( + "huggingface-tei-cpu", + region=region, + version=version, + image_scope="inference", + ) if backend == "lmi": version = version or "0.24.0" return image_uris.retrieve(framework="djl-deepspeed", region=region, version=version) @@ -109,3 +125,26 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = "Did not find model metadata for the following HuggingFace Model ID %s" % model_id ) return hf_model_metadata_json + + +def download_huggingface_model_metadata( + model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None +) -> None: + """Downloads the HuggingFace Model snapshot via HuggingFace API. + + Args: + model_id (str): The HuggingFace Model ID + model_local_path (str): The local path to save the HuggingFace Model snapshot. + hf_hub_token (str): The HuggingFace Hub Token + + Raises: + ImportError: If huggingface_hub is not installed. 
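Example (an illustrative call; ``gpt2`` is a public Hugging Face model ID and the local path is hypothetical):

.. code:: python

    download_huggingface_model_metadata("gpt2", "/tmp/models/gpt2")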
+ """ + if not importlib.util.find_spec("huggingface_hub"): + raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed") + + from huggingface_hub import snapshot_download + + os.makedirs(model_local_path, exist_ok=True) + logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path) + snapshot_download(repo_id=model_id, local_dir=model_local_path, token=hf_hub_token) diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index f71dca0ac8..3ca25fb3ce 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -26,12 +26,17 @@ ) from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.session import Session from sagemaker.utils import to_string, format_tags from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger("sagemaker") @@ -118,7 +123,7 @@ def __init__( pytorch_version: Optional[str] = None, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = HuggingFacePredictor, + predictor_cls: Optional[Callable] = HuggingFacePredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): @@ -153,7 +158,7 @@ def __init__( If not specified, a default image for PyTorch will be used. If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -213,6 +218,7 @@ def deploy( container_startup_health_check_timeout=None, inference_recommendation_id=None, explainer_config=None, + update_endpoint: Optional[bool] = False, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -291,6 +297,11 @@ def deploy( would like to deploy the model and endpoint with recommended parameters. explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability configuration for use with Amazon SageMaker Clarify. (default: None) + update_endpoint (Optional[bool]): + Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint + and delete resources corresponding to the previous EndpointConfig. 
Default: False + Note: Currently this is supported for single model endpoints. Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -299,7 +310,7 @@ def deploy( - If a wrong type of object is provided as serverless inference config or async inference config Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ @@ -330,10 +341,8 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, inference_recommendation_id=inference_recommendation_id, explainer_config=explainer_config, - endpoint_logging=kwargs.get("endpoint_logging", False), - endpoint_type=kwargs.get("endpoint_type", None), - resources=kwargs.get("resources", None), - managed_instance_scaling=kwargs.get("managed_instance_scaling", None), + update_endpoint=update_endpoint, + **kwargs, ) def register( @@ -361,6 +370,8 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -413,6 +424,9 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative and + quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -461,6 +475,8 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_life_cycle=model_life_cycle, + model_card=model_card, ) def prepare_container_def( @@ -470,6 +486,7 @@ serverless_inference_config=None, inference_tool=None, accept_eula=None, + model_reference_arn=None, ): """A container definition with framework configuration set in model environment variables.
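To illustrate the new ``update_endpoint`` flag on ``deploy`` (a minimal sketch; the role, model data location, endpoint name, and version pins are hypothetical):

.. code:: python

    from sagemaker.huggingface import HuggingFaceModel

    model = HuggingFaceModel(
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # hypothetical
        model_data="s3://my-bucket/model.tar.gz",  # hypothetical
        transformers_version="4.26",  # hypothetical version pins
        pytorch_version="1.13",
        py_version="py39",
    )

    # With update_endpoint=True, deploy() creates a new EndpointConfig and points
    # the existing endpoint at it, deleting resources tied to the old config.
    predictor = model.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.xlarge",
        endpoint_name="my-existing-endpoint",  # hypothetical existing endpoint
        update_endpoint=True,
    )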
@@ -523,7 +540,9 @@ def prepare_container_def( deploy_image, self.repacked_model_data or self.model_data, deploy_env, + image_config=self.image_config, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 5873e37b9f..86208858de 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -20,7 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.enums import HyperparameterValidationMode +from sagemaker.jumpstart.enums import HyperparameterValidationMode, JumpStartModelType from sagemaker.jumpstart.validators import validate_hyperparameters from sagemaker.session import Session @@ -31,11 +31,14 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, instance_type: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Dict[str, str]: """Retrieves the default training hyperparameters for the model matching the given arguments. @@ -46,6 +49,8 @@ def retrieve_default( retrieve the default hyperparameters. (Default: None). model_version (str): The version of the model for which to retrieve the default hyperparameters. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. include_container_hyperparameters (bool): ``True`` if the container hyperparameters @@ -66,6 +71,9 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: The hyperparameters to use for the model. @@ -80,18 +88,22 @@ def retrieve_default( return artifacts._retrieve_default_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, include_container_hyperparameters=include_container_hyperparameters, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) def validate( region: Optional[str] = None, model_id: Optional[str] = None, + hub_arn: Optional[str] = None, model_version: Optional[str] = None, hyperparameters: Optional[dict] = None, validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, @@ -107,6 +119,8 @@ def validate( (Default: None). model_version (str): The version of the model for which to validate hyperparameters. (Default: None). 
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). hyperparameters (dict): Hyperparameters to validate. (Default: None). validation_mode (HyperparameterValidationMode): Method of validation to use with @@ -148,6 +162,7 @@ def validate( return validate_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, hyperparameters=hyperparameters, validation_mode=validation_mode, region=region, diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 57ce47f94c..8d2f169b31 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -11,7 +11,10 @@ "0.6": "0.6.2", "0.7": "0.7.0", "0.8": "0.8.2", - "1.0": "1.0.0" + "1.0": "1.0.0", + "1.1": "1.1.1", + "1.2": "1.2.0", + "1.3": "1.3.0" }, "versions": { "0.3.1": { @@ -480,6 +483,170 @@ "py_versions": [ "py310" ] + }, + "1.1.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "autogluon-training", + "processors": [ + "cpu", + "gpu" + ], + "py_versions": [ + "py310" + ] + }, + "1.1.1": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "autogluon-training", + "processors": [ + "cpu", + "gpu" + ], + "py_versions": [ + "py311" + ] + }, + "1.2.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": 
"907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "autogluon-training", + "processors": [ + "cpu", + "gpu" + ], + "py_versions": [ + "py311" + ] + }, + "1.3.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "autogluon-training", + "processors": [ + "cpu", + "gpu" + ], + "py_versions": [ + "py311" + ] } } }, @@ -491,7 +658,10 @@ "0.6": "0.6.2", "0.7": "0.7.0", "0.8": "0.8.2", - "1.0": "1.0.0" + "1.0": "1.0.0", + "1.1": "1.1.1", + "1.2": "1.2.0", + "1.3": "1.3.0" }, "versions": { "0.3.1": { @@ -987,6 +1157,178 @@ "py_versions": [ "py310" ] + }, + "1.1.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "autogluon-inference", + "processors": [ + "cpu", + "gpu" + ], + "py_versions": [ + "py310" + ] + }, + "1.1.1": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": 
"763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "autogluon-inference", + "processors": [ + "cpu", + "gpu" + ], + "py_versions": [ + "py311" + ] + }, + "1.2.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "autogluon-inference", + "processors": [ + "cpu", + "gpu" + ], + "py_versions": [ + "py311" + ] + }, + "1.3.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "autogluon-inference", + "processors": [ + "cpu", + "gpu" + ], + "py_versions": [ + "py311" + ] } } } diff --git a/src/sagemaker/image_uri_config/blazingtext.json b/src/sagemaker/image_uri_config/blazingtext.json index eba76fc80c..b1768d7f9b 100644 --- a/src/sagemaker/image_uri_config/blazingtext.json +++ 
b/src/sagemaker/image_uri_config/blazingtext.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "433757028032" }, diff --git a/src/sagemaker/image_uri_config/clarify.json b/src/sagemaker/image_uri_config/clarify.json index 4058a8803f..3f74ac9493 100644 --- a/src/sagemaker/image_uri_config/clarify.json +++ b/src/sagemaker/image_uri_config/clarify.json @@ -16,6 +16,7 @@ "cn-north-1": "122526803553", "cn-northwest-1": "122578899357", "eu-central-1": "017069133835", + "eu-central-2": "730335477804", "eu-north-1": "763603941244", "eu-south-1": "638885417683", "eu-west-1": "131013547314", @@ -26,6 +27,8 @@ "us-east-1": "205585389593", "us-east-2": "211330385671", "us-gov-west-1": "598674086554", + "us-isof-east-1": "579539705040", + "us-isof-south-1": "411392592546", "us-west-1": "740489534195", "us-west-2": "306415355426" }, diff --git a/src/sagemaker/image_uri_config/djl-deepspeed.json b/src/sagemaker/image_uri_config/djl-deepspeed.json index 133a744f6c..e98e382b0b 100644 --- a/src/sagemaker/image_uri_config/djl-deepspeed.json +++ b/src/sagemaker/image_uri_config/djl-deepspeed.json @@ -28,6 +28,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" @@ -60,6 +62,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" @@ -92,6 +96,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" @@ -124,6 +130,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" @@ -156,6 +164,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" @@ -188,6 +198,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" @@ -220,6 +232,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" @@ -252,6 +266,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" @@ -284,6 +300,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": 
"442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" diff --git a/src/sagemaker/image_uri_config/djl-lmi.json b/src/sagemaker/image_uri_config/djl-lmi.json new file mode 100644 index 0000000000..0a741036c1 --- /dev/null +++ b/src/sagemaker/image_uri_config/djl-lmi.json @@ -0,0 +1,115 @@ +{ + "scope": [ + "inference" + ], + "version_aliases": { + "latest": "0.30.0" + }, + "versions": { + "0.30.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.30.0-lmi12.0.0-cu124" + }, + "0.29.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.29.0-lmi11.0.0-cu124" + }, + "0.28.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + 
"ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.28.0-lmi10.0.0-cu124" + } + } +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/djl-neuronx.json b/src/sagemaker/image_uri_config/djl-neuronx.json index a63acc87e4..1fd7492ff4 100644 --- a/src/sagemaker/image_uri_config/djl-neuronx.json +++ b/src/sagemaker/image_uri_config/djl-neuronx.json @@ -2,19 +2,82 @@ "scope": [ "inference" ], + "version_aliases": { + "latest": "0.29.0" + }, "versions": { + "0.29.0": { + "registries": { + "ap-northeast-1": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-west-1": "763104351884", + "eu-west-3": "763104351884", + "mx-central-1":"637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.29.0-neuronx-sdk2.19.1" + }, + "0.28.0": { + "registries": { + "ap-northeast-1": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-west-1": "763104351884", + "eu-west-3": "763104351884", + "mx-central-1":"637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.28.0-neuronx-sdk2.18.2" + }, "0.27.0": { "registries": { "ap-northeast-1": "763104351884", "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-2": "763104351884", "ca-west-1": "204538143572" }, @@ -27,12 +90,20 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-2": "763104351884", "ca-west-1": "204538143572" }, @@ -45,12 +116,20 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-5": "550225433462", + 
"ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-2": "763104351884", "ca-west-1": "204538143572" }, @@ -63,12 +142,20 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-2": "763104351884", "ca-west-1": "204538143572" }, @@ -81,12 +168,20 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-2": "763104351884", "ca-west-1": "204538143572" }, @@ -99,12 +194,20 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", "eu-west-3": "763104351884", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-2": "763104351884", "ca-west-1": "204538143572" }, diff --git a/src/sagemaker/image_uri_config/djl-tensorrtllm.json b/src/sagemaker/image_uri_config/djl-tensorrtllm.json index e125cbd419..cd1e59bad8 100644 --- a/src/sagemaker/image_uri_config/djl-tensorrtllm.json +++ b/src/sagemaker/image_uri_config/djl-tensorrtllm.json @@ -2,7 +2,112 @@ "scope": [ "inference" ], + "version_aliases": { + "latest": "0.30.0" + }, "versions": { + "0.30.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + 
"us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.30.0-tensorrtllm0.12.0-cu125" + }, + "0.29.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.29.0-tensorrtllm0.11.0-cu124" + }, + "0.28.0": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "djl-inference", + "tag_prefix": "0.28.0-tensorrtllm0.9.0-cu122" + }, "0.27.0": { "registries": { "af-south-1": "626614931356", @@ -28,6 +133,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" @@ -60,6 +167,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" @@ -92,6 +201,8 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", "us-west-2": "763104351884", "ca-west-1": "204538143572" diff --git a/src/sagemaker/image_uri_config/factorization-machines.json b/src/sagemaker/image_uri_config/factorization-machines.json index a97ef3b374..b99927f757 100644 --- a/src/sagemaker/image_uri_config/factorization-machines.json +++ b/src/sagemaker/image_uri_config/factorization-machines.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", 
"us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "174872318107" }, diff --git a/src/sagemaker/image_uri_config/forecasting-deepar.json b/src/sagemaker/image_uri_config/forecasting-deepar.json index 5bff449425..c0b3d77786 100644 --- a/src/sagemaker/image_uri_config/forecasting-deepar.json +++ b/src/sagemaker/image_uri_config/forecasting-deepar.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "156387875391" }, diff --git a/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json b/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json index 9da18c1b56..1c425b37ec 100644 --- a/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json +++ b/src/sagemaker/image_uri_config/huggingface-llm-neuronx.json @@ -4,7 +4,7 @@ "inf2" ], "version_aliases": { - "0.0": "0.0.16" + "0.0": "0.0.28" }, "versions": { "0.0.16": { @@ -12,23 +12,48 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.16", "repository": "huggingface-pytorch-tgi-inference", @@ -41,23 +66,48 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", 
"eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.17", "repository": "huggingface-pytorch-tgi-inference", @@ -70,23 +120,48 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.18", "repository": "huggingface-pytorch-tgi-inference", @@ -99,23 +174,48 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": 
"204538143572" + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.19", "repository": "huggingface-pytorch-tgi-inference", @@ -128,23 +228,48 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.20", "repository": "huggingface-pytorch-tgi-inference", @@ -157,30 +282,379 @@ "py310" ], "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", "eu-south-2": "503227376785", "eu-west-1": "763104351884", + "eu-west-2": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" }, "tag_prefix": "1.13.1-optimum0.0.21", "repository": "huggingface-pytorch-tgi-inference", 
"container_version": { "inf2": "ubuntu22.04" } + }, + "0.0.22": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.1.2-optimum0.0.22", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "inf2": "ubuntu22.04" + } + }, + "0.0.23": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.1.2-optimum0.0.23", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "inf2": "ubuntu22.04" + } + }, + "0.0.24": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + 
"ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.1.2-optimum0.0.24", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "inf2": "ubuntu22.04" + } + }, + "0.0.25": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.1.2-optimum0.0.25", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "inf2": "ubuntu22.04" + } + }, + "0.0.27": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + 
"ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.1.2-optimum0.0.27", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "inf2": "ubuntu22.04" + } + }, + "0.0.28": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.1.2-optimum0.0.28", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "inf2": "ubuntu22.04" + } } } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/huggingface-llm.json b/src/sagemaker/image_uri_config/huggingface-llm.json index 10073338e7..58fffa0ed9 100644 --- a/src/sagemaker/image_uri_config/huggingface-llm.json +++ b/src/sagemaker/image_uri_config/huggingface-llm.json @@ -12,7 +12,11 @@ "1.2": "1.2.0", "1.3": "1.3.3", "1.4": "1.4.5", - "2.0": "2.0.0" + "2.0": "2.4.0", + "2.3": "2.3.1", + "3.0": "3.0.1", + "3.2": "3.2.3", + "3.1": "3.1.1" }, "versions": { "0.6.0": { @@ -21,8 +25,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -32,19 +36,25 @@ 
"ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -52,9 +62,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.0.0-tgi0.6.0", "repository": "huggingface-pytorch-tgi-inference", @@ -68,8 +79,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -79,19 +90,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -99,9 +116,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.0.0-tgi0.8.2", "repository": "huggingface-pytorch-tgi-inference", @@ -115,8 +133,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -126,19 +144,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": 
"380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -146,9 +170,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.0.1-tgi0.9.3", "repository": "huggingface-pytorch-tgi-inference", @@ -162,8 +187,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -173,19 +198,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -193,9 +224,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.0.1-tgi1.0.3", "repository": "huggingface-pytorch-tgi-inference", @@ -209,8 +241,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -220,19 +252,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": 
"217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -240,9 +278,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.0.1-tgi1.1.0", "repository": "huggingface-pytorch-tgi-inference", @@ -256,8 +295,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -267,19 +306,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -287,9 +332,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.2.0", "repository": "huggingface-pytorch-tgi-inference", @@ -303,8 +349,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -314,19 +360,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -334,9 +386,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": 
"763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.3.1", "repository": "huggingface-pytorch-tgi-inference", @@ -350,8 +403,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -361,19 +414,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -381,9 +440,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.3.3", "repository": "huggingface-pytorch-tgi-inference", @@ -397,8 +457,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -408,19 +468,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -428,9 +494,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.4.0", "repository": "huggingface-pytorch-tgi-inference", @@ -444,8 +511,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": 
"763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -455,19 +522,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -475,9 +548,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "tag_prefix": "2.1.1-tgi1.4.2", "repository": "huggingface-pytorch-tgi-inference", @@ -491,8 +565,62 @@ ], "registries": { "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.1.1-tgi1.4.5", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu121-ubuntu22.04" + } + }, + "2.0.0": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -502,19 +630,133 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": 
"763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", - "eu-west-3": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.1.1-tgi2.0.0", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu121-ubuntu22.04" + } + }, + "2.0.1": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.1.1-tgi2.0.1", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu121-ubuntu22.04" + } + }, + "2.0.2": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + 
"eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -522,24 +764,133 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "tag_prefix": "2.1.1-tgi1.4.5", + "tag_prefix": "2.3.0-tgi2.0.2", "repository": "huggingface-pytorch-tgi-inference", "container_version": { "gpu": "cu121-ubuntu22.04" } }, - "2.0.0": { + "2.2.0": { "py_versions": [ "py310" ], "registries": { "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.3.0-tgi2.2.0", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu121-ubuntu22.04-v2.0" + } + }, + "2.3.1": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + 
"sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.4.0-tgi2.3.1", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04" + } + }, + "2.4.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -549,19 +900,133 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", - "eu-west-3": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.4.0-tgi2.4.0", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04-v2.2" + } + }, + "3.0.1": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, 
+ "tag_prefix": "2.4.0-tgi3.0.1", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04-v2.1" + } + }, + "3.1.1": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -569,16 +1034,125 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "tag_prefix": "2.1.1-tgi2.0.0", + "tag_prefix": "2.6.0-tgi3.1.1", "repository": "huggingface-pytorch-tgi-inference", "container_version": { - "gpu": "cu121-ubuntu22.04" + "gpu": "cu124-ubuntu22.04" + } + }, + "3.2.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.6.0-tgi3.2.0", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04" + } + }, + "3.2.3": { + "py_versions": [ + "py311" + ], + 
"registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "tag_prefix": "2.6.0-tgi3.2.3", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04" } } } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json index ae38ce209b..2a68282327 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuron.json +++ b/src/sagemaker/image_uri_config/huggingface-neuron.json @@ -17,18 +17,22 @@ ], "repository": "huggingface-pytorch-inference-neuron", "registries": { + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-neuronx.json b/src/sagemaker/image_uri_config/huggingface-neuronx.json index 0d8b7268b1..d39d58bb9e 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuronx.json +++ b/src/sagemaker/image_uri_config/huggingface-neuronx.json @@ -5,7 +5,10 @@ ], "version_aliases": { "4.28": "4.28.1", - "4.34": "4.34.1" + "4.34": "4.34.1", + "4.36": "4.36.2", + "4.43": "4.43.2", + "4.48": "4.48.1" }, "versions": { "4.28.1": { @@ -18,21 +21,29 @@ ], "repository": "huggingface-pytorch-training-neuronx", "registries": { + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": 
"763104351884", "eu-central-2": "380420809688", "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-2": "763104351884", "ca-west-1": "204538143572" }, @@ -54,21 +65,29 @@ ], "repository": "huggingface-pytorch-inference-neuronx", "registries": { + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-south-1": "763104351884", "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-3": "763104351884", "il-central-1": "780543022126", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-2": "763104351884", "ca-west-1": "204538143572" }, @@ -79,6 +98,136 @@ "sdk2.15.0" ] } + }, + "4.36.2": { + "version_aliases": { + "pytorch1.13": "pytorch1.13.1" + }, + "pytorch1.13.1": { + "py_versions": [ + "py310" + ], + "repository": "huggingface-pytorch-inference-neuronx", + "registries": { + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "mx-central-1":"637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "container_version": { + "inf": "ubuntu20.04" + }, + "sdk_versions": [ + "sdk2.18.0" + ] + } + }, + "4.43.2": { + "version_aliases": { + "pytorch2.1": "pytorch2.1.2" + }, + "pytorch2.1.2": { + "py_versions": [ + "py310" + ], + "repository": "huggingface-pytorch-inference-neuronx", + "registries": { + "ap-northeast-1": "763104351884", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "mx-central-1":"637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "container_version": { + "inf": "ubuntu20.04" + 
}, + "sdk_versions": [ + "sdk2.20.0" + ] + } + }, + "4.48.1": { + "version_aliases": { + "pytorch2.1": "pytorch2.1.2" + }, + "pytorch2.1.2": { + "py_versions": [ + "py310" + ], + "repository": "huggingface-pytorch-inference-neuronx", + "registries": { + "ap-northeast-1": "763104351884", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "mx-central-1":"637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "container_version": { + "inf": "ubuntu20.04" + }, + "sdk_versions": [ + "sdk2.20.0" + ] + } } } }, @@ -89,7 +238,8 @@ "version_aliases": { "4.28": "4.28.1", "4.34": "4.34.1", - "4.36": "4.36.2" + "4.36": "4.36.2", + "4.43": "4.43.2" }, "versions": { "4.28.1": { @@ -105,6 +255,7 @@ "af-south-1": "626614931356", "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -113,6 +264,8 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -125,6 +278,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -157,6 +311,7 @@ "af-south-1": "626614931356", "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -165,6 +320,8 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -177,6 +334,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -198,7 +356,8 @@ }, "4.36.2": { "version_aliases": { - "pytorch1.13": "pytorch1.13.1" + "pytorch1.13": "pytorch1.13.1", + "pytorch2.1": "pytorch2.1.2" }, "pytorch1.13.1": { "py_versions": [ @@ -209,6 +368,7 @@ "af-south-1": "626614931356", "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -217,6 +377,8 @@ "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "cn-north-1": "727897471807", 
"cn-northwest-1": "727897471807", @@ -229,6 +391,7 @@ "eu-south-1": "692866216735", "eu-south-2": "503227376785", "me-south-1": "217643126080", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -246,6 +409,100 @@ "sdk_versions": [ "sdk2.16.1" ] + }, + "pytorch2.1.2": { + "py_versions": [ + "py310" + ], + "repository": "huggingface-pytorch-inference-neuronx", + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "mx-central-1":"637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "container_version": { + "inf": "ubuntu20.04" + }, + "sdk_versions": [ + "sdk2.18.0" + ] + } + }, + "4.43.2": { + "version_aliases": { + "pytorch2.1": "pytorch2.1.2" + }, + "pytorch2.1.2": { + "py_versions": [ + "py310" + ], + "repository": "huggingface-pytorch-inference-neuronx", + "registries": { + "ap-northeast-1": "763104351884", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "mx-central-1":"637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "container_version": { + "inf": "ubuntu20.04" + }, + "sdk_versions": [ + "sdk2.20.0" + ] } } } diff --git a/src/sagemaker/image_uri_config/huggingface-tei-cpu.json b/src/sagemaker/image_uri_config/huggingface-tei-cpu.json new file mode 100644 index 0000000000..3af1ed5de6 --- /dev/null +++ b/src/sagemaker/image_uri_config/huggingface-tei-cpu.json @@ -0,0 +1,203 @@ +{ + "inference": { + "processors": [ + "cpu" + ], + "version_aliases": { + "1.2": "1.2.3", + "1.4": "1.4.0", + "1.6": "1.6.0", + "1.7": "1.7.0" + }, + "versions": { + "1.2.3": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "510948584623", + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-northeast-3": 
"867004704886", + "ap-south-1": "720646828776", + "ap-south-2": "628508329040", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", + "ap-southeast-4": "106583098589", + "ca-central-1": "341280168497", + "ca-west-1": "190319476487", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-central-2": "680994064768", + "eu-north-1": "662702820516", + "eu-south-1": "978288397137", + "eu-south-2": "104374241257", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "il-central-1": "898809789911", + "me-central-1": "272398656194", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-east-1": "237065988967", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-isob-east-1": "281123927165", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "tag_prefix": "2.0.1-tei1.2.3", + "repository": "tei-cpu", + "container_version": { + "cpu": "ubuntu22.04" + } + }, + "1.4.0":{ + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "510948584623", + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-northeast-3": "867004704886", + "ap-south-1": "720646828776", + "ap-south-2": "628508329040", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", + "ap-southeast-4": "106583098589", + "ca-central-1": "341280168497", + "ca-west-1": "190319476487", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-central-2": "680994064768", + "eu-north-1": "662702820516", + "eu-south-1": "978288397137", + "eu-south-2": "104374241257", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "il-central-1": "898809789911", + "me-central-1": "272398656194", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-east-1": "237065988967", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-isob-east-1": "281123927165", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "tag_prefix": "2.0.1-tei1.4.0", + "repository": "tei-cpu", + "container_version": { + "cpu": "ubuntu22.04" + } + }, + "1.6.0":{ + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "510948584623", + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-northeast-3": "867004704886", + "ap-south-1": "720646828776", + "ap-south-2": "628508329040", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", + "ap-southeast-4": "106583098589", + "ca-central-1": "341280168497", + "ca-west-1": "190319476487", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-central-2": "680994064768", + "eu-north-1": "662702820516", + "eu-south-1": "978288397137", + "eu-south-2": "104374241257", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "il-central-1": "898809789911", + "me-central-1": "272398656194", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-east-1": "237065988967", + 
"us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-isob-east-1": "281123927165", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "tag_prefix": "2.0.1-tei1.6.0", + "repository": "tei-cpu", + "container_version": { + "cpu": "ubuntu22.04" + } + }, + "1.7.0":{ + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "510948584623", + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-northeast-3": "867004704886", + "ap-south-1": "720646828776", + "ap-south-2": "628508329040", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", + "ap-southeast-4": "106583098589", + "ca-central-1": "341280168497", + "ca-west-1": "190319476487", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-central-2": "680994064768", + "eu-north-1": "662702820516", + "eu-south-1": "978288397137", + "eu-south-2": "104374241257", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "il-central-1": "898809789911", + "me-central-1": "272398656194", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-east-1": "237065988967", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-isob-east-1": "281123927165", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "tag_prefix": "2.0.1-tei1.7.0", + "repository": "tei-cpu", + "container_version": { + "cpu": "ubuntu22.04" + } + } + } + } +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/huggingface-tei.json b/src/sagemaker/image_uri_config/huggingface-tei.json new file mode 100644 index 0000000000..eaf08230c7 --- /dev/null +++ b/src/sagemaker/image_uri_config/huggingface-tei.json @@ -0,0 +1,203 @@ +{ + "inference": { + "processors": [ + "gpu" + ], + "version_aliases": { + "1.2": "1.2.3", + "1.4": "1.4.0", + "1.6": "1.6.0", + "1.7": "1.7.0" + }, + "versions": { + "1.2.3": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "510948584623", + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-northeast-3": "867004704886", + "ap-south-1": "720646828776", + "ap-south-2": "628508329040", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", + "ap-southeast-4": "106583098589", + "ca-central-1": "341280168497", + "ca-west-1": "190319476487", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-central-2": "680994064768", + "eu-north-1": "662702820516", + "eu-south-1": "978288397137", + "eu-south-2": "104374241257", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "il-central-1": "898809789911", + "me-central-1": "272398656194", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-east-1": "237065988967", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-isob-east-1": "281123927165", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "tag_prefix": "2.0.1-tei1.2.3", + "repository": "tei", + "container_version": { + "gpu": "cu122-ubuntu22.04" + } + }, + "1.4.0": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "510948584623", + 
"ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-northeast-3": "867004704886", + "ap-south-1": "720646828776", + "ap-south-2": "628508329040", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", + "ap-southeast-4": "106583098589", + "ca-central-1": "341280168497", + "ca-west-1": "190319476487", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-central-2": "680994064768", + "eu-north-1": "662702820516", + "eu-south-1": "978288397137", + "eu-south-2": "104374241257", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "il-central-1": "898809789911", + "me-central-1": "272398656194", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-east-1": "237065988967", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-isob-east-1": "281123927165", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "tag_prefix": "2.0.1-tei1.4.0", + "repository": "tei", + "container_version": { + "gpu": "cu122-ubuntu22.04" + } + }, + "1.6.0": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "510948584623", + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-northeast-3": "867004704886", + "ap-south-1": "720646828776", + "ap-south-2": "628508329040", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", + "ap-southeast-4": "106583098589", + "ca-central-1": "341280168497", + "ca-west-1": "190319476487", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-central-2": "680994064768", + "eu-north-1": "662702820516", + "eu-south-1": "978288397137", + "eu-south-2": "104374241257", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "il-central-1": "898809789911", + "me-central-1": "272398656194", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-east-1": "237065988967", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-isob-east-1": "281123927165", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "tag_prefix": "2.0.1-tei1.6.0", + "repository": "tei", + "container_version": { + "gpu": "cu122-ubuntu22.04" + } + }, + "1.7.0": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "510948584623", + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-northeast-3": "867004704886", + "ap-south-1": "720646828776", + "ap-south-2": "628508329040", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", + "ap-southeast-4": "106583098589", + "ca-central-1": "341280168497", + "ca-west-1": "190319476487", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-central-2": "680994064768", + "eu-north-1": "662702820516", + "eu-south-1": "978288397137", + "eu-south-2": "104374241257", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "il-central-1": "898809789911", + "me-central-1": "272398656194", + "me-south-1": "801668240914", + 
"sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-east-1": "237065988967", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-isob-east-1": "281123927165", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "tag_prefix": "2.0.1-tei1.7.0", + "repository": "tei", + "container_version": { + "gpu": "cu122-ubuntu22.04" + } + } + } + } +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/huggingface-training-compiler.json b/src/sagemaker/image_uri_config/huggingface-training-compiler.json index 735e7917b3..c84469acc2 100644 --- a/src/sagemaker/image_uri_config/huggingface-training-compiler.json +++ b/src/sagemaker/image_uri_config/huggingface-training-compiler.json @@ -60,6 +60,7 @@ "af-south-1": "626614931356", "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -69,6 +70,8 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-central-2": "380420809688", @@ -80,6 +83,7 @@ "eu-west-3": "763104351884", "me-south-1": "217643126080", "me-central-1": "914824155844", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -100,6 +104,7 @@ "af-south-1": "626614931356", "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -109,6 +114,8 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-central-2": "380420809688", @@ -120,6 +127,7 @@ "eu-west-3": "763104351884", "me-south-1": "217643126080", "me-central-1": "914824155844", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -145,6 +153,7 @@ "af-south-1": "626614931356", "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -154,6 +163,8 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-central-2": "380420809688", @@ -165,6 +176,7 @@ "eu-west-3": "763104351884", "me-south-1": "217643126080", "me-central-1": "914824155844", + "mx-central-1":"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 930b24566d..475a82aeec 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -13,7 +13,10 @@ "4.17": "4.17.0", "4.26": "4.26.0", "4.28": "4.28.1", - "4.36": "4.36.0" + "4.36": "4.36.0", + "4.46": "4.46.1", + "4.48": "4.48.0", + "4.49": "4.49.0" }, "versions": { "4.4.2": { @@ -1018,6 +1021,147 @@ 
"gpu": "cu121-ubuntu20.04" } } + }, + "4.46.1": { + "version_aliases": { + "pytorch2.3": "pytorch2.3.0" + }, + "pytorch2.3.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "huggingface-pytorch-training", + "container_version": { + "gpu": "cu121-ubuntu20.04" + } + } + }, + "4.48.0": { + "version_aliases": { + "pytorch2.3": "pytorch2.3.0" + }, + "pytorch2.3.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "huggingface-pytorch-training", + "container_version": { + "gpu": "cu121-ubuntu20.04" + } + } + }, + "4.49.0": { + "version_aliases": { + "pytorch2.5": "pytorch2.5.1" + }, + "pytorch2.5.1": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": 
"763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "huggingface-pytorch-training", + "container_version": { + "gpu": "cu124-ubuntu22.04" + } + } } } }, @@ -1034,7 +1178,8 @@ "4.17": "4.17.0", "4.26": "4.26.0", "4.28": "4.28.1", - "4.37": "4.37.0" + "4.37": "4.37.0", + "4.49": "4.49.0" }, "versions": { "4.6.1": { @@ -1883,6 +2028,110 @@ "cpu": "ubuntu22.04" } } + }, + "4.48.0": { + "version_aliases": { + "pytorch2.3": "pytorch2.3.0" + }, + "pytorch2.3.0": { + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "huggingface-pytorch-inference", + "container_version": { + "gpu": "cu121-ubuntu22.04", + "cpu": "ubuntu22.04" + } + } + }, + "4.49.0": { + "version_aliases": { + "pytorch2.6": "pytorch2.6.0" + }, + "pytorch2.6.0": { + "py_versions": [ + "py312" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "huggingface-pytorch-inference", + "container_version": { + "gpu": "cu124-ubuntu22.04", + "cpu": "ubuntu22.04" + } + } } } } diff --git 
a/src/sagemaker/image_uri_config/hyperpod-recipes-neuron.json b/src/sagemaker/image_uri_config/hyperpod-recipes-neuron.json new file mode 100644 index 0000000000..cd5a69bfe2 --- /dev/null +++ b/src/sagemaker/image_uri_config/hyperpod-recipes-neuron.json @@ -0,0 +1,52 @@ +{ + "training": { + "processors": [ + "neuronx" + ], + "version_aliases": { + "2.1.2": "2.1.2" + }, + "versions": { + "2.1.2": { + "py_versions": [ + "py310" + ], + "repository": "pytorch-training-neuronx", + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572", + "ca-central-1": "763104351884" + }, + "container_version": { + "neuronx": "ubuntu20.04" + }, + "sdk_versions": [ + "sdk2.20.2" + ] + } + } + } +} diff --git a/src/sagemaker/image_uri_config/image-classification-neo.json b/src/sagemaker/image_uri_config/image-classification-neo.json index 66c2da8481..09e019a7de 100644 --- a/src/sagemaker/image_uri_config/image-classification-neo.json +++ b/src/sagemaker/image_uri_config/image-classification-neo.json @@ -1,5 +1,7 @@ { - "scope": ["inference"], + "scope": [ + "inference" + ], "versions": { "latest": { "registries": { @@ -15,23 +17,27 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "image-classification-neo" } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/image-classification.json b/src/sagemaker/image_uri_config/image-classification.json index 67c926f779..4fed1dea20 100644 --- a/src/sagemaker/image_uri_config/image-classification.json +++ b/src/sagemaker/image_uri_config/image-classification.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "433757028032" }, diff --git a/src/sagemaker/image_uri_config/inferentia-mxnet.json b/src/sagemaker/image_uri_config/inferentia-mxnet.json index 5a371430ce..8cefb7cf0f 100644 --- a/src/sagemaker/image_uri_config/inferentia-mxnet.json +++ 
b/src/sagemaker/image_uri_config/inferentia-mxnet.json @@ -1,9 +1,15 @@ { - "processors": ["inf"], - "scope": ["inference"], + "processors": [ + "inf" + ], + "scope": [ + "inference" + ], "versions": { "1.5.1": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -17,26 +23,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-neo-mxnet" }, "1.8": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -50,23 +62,27 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-neo-mxnet" } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/inferentia-pytorch.json b/src/sagemaker/image_uri_config/inferentia-pytorch.json index 479ee29218..db61749216 100644 --- a/src/sagemaker/image_uri_config/inferentia-pytorch.json +++ b/src/sagemaker/image_uri_config/inferentia-pytorch.json @@ -1,9 +1,15 @@ { - "processors": ["inf"], - "scope": ["inference"], + "processors": [ + "inf" + ], + "scope": [ + "inference" + ], "versions": { "1.7": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -17,26 +23,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": 
"sagemaker-neo-pytorch" }, "1.8": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -50,26 +62,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-neo-pytorch" }, "1.9": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -83,23 +101,27 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-neo-pytorch" } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/inferentia-tensorflow.json b/src/sagemaker/image_uri_config/inferentia-tensorflow.json index ca2bbcdb8b..7954d2c11f 100644 --- a/src/sagemaker/image_uri_config/inferentia-tensorflow.json +++ b/src/sagemaker/image_uri_config/inferentia-tensorflow.json @@ -1,9 +1,15 @@ { - "processors": ["inf"], - "scope": ["inference"], + "processors": [ + "inf" + ], + "scope": [ + "inference" + ], "versions": { "1.15.0": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -17,26 +23,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-neo-tensorflow" }, "2.5.2": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ 
-50,23 +62,27 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-neo-tensorflow" } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/instance_gpu_info.json b/src/sagemaker/image_uri_config/instance_gpu_info.json index 9fc005bc47..e64a9bcf88 100644 --- a/src/sagemaker/image_uri_config/instance_gpu_info.json +++ b/src/sagemaker/image_uri_config/instance_gpu_info.json @@ -23,7 +23,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-east-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -49,7 +49,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-northeast-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -75,7 +75,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-northeast-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -101,7 +101,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-northeast-3": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -127,7 +127,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-south-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -153,7 +153,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 
183104} }, "ap-southeast-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -179,7 +179,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-southeast-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -205,7 +205,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ap-southeast-3": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -231,7 +231,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "ca-central-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -257,7 +257,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "cn-north-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -283,7 +283,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "cn-northwest-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -309,7 +309,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-central-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -335,7 +335,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-central-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -361,7 +361,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-north-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -387,7 +387,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, 
"TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-south-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -413,7 +413,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-south-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -439,7 +439,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-west-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -465,7 +465,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-west-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -491,7 +491,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "eu-west-3": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -517,7 +517,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "il-central-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -543,7 +543,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "me-central-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -569,7 +569,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "me-south-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -595,7 +595,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, 
"sa-east-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -621,7 +621,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-east-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -647,7 +647,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-east-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -673,7 +673,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-gov-east-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -699,7 +699,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-gov-west-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -725,7 +725,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-west-1": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -751,7 +751,7 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} }, "us-west-2": { "ml.p5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 655360}, @@ -777,6 +777,6 @@ "ml.g5.16xlarge": {"Count": 1, "TotalGpuMemoryInMiB": 24576}, "ml.g5.12xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, "ml.g5.24xlarge": {"Count": 4, "TotalGpuMemoryInMiB": 98304}, - "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 196608} + "ml.g5.48xlarge": {"Count": 8, "TotalGpuMemoryInMiB": 183104} } } \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/ipinsights.json b/src/sagemaker/image_uri_config/ipinsights.json index 8840b01473..175ab06c38 100644 --- a/src/sagemaker/image_uri_config/ipinsights.json +++ b/src/sagemaker/image_uri_config/ipinsights.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "174872318107" }, diff --git a/src/sagemaker/image_uri_config/kmeans.json b/src/sagemaker/image_uri_config/kmeans.json index 
9b181a75f5..bffca1e5b7 100644 --- a/src/sagemaker/image_uri_config/kmeans.json +++ b/src/sagemaker/image_uri_config/kmeans.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "174872318107" }, diff --git a/src/sagemaker/image_uri_config/knn.json b/src/sagemaker/image_uri_config/knn.json index 4d561f694d..e36777c8c9 100644 --- a/src/sagemaker/image_uri_config/knn.json +++ b/src/sagemaker/image_uri_config/knn.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "174872318107" }, diff --git a/src/sagemaker/image_uri_config/linear-learner.json b/src/sagemaker/image_uri_config/linear-learner.json index c3dafc49bc..c0e615bae7 100644 --- a/src/sagemaker/image_uri_config/linear-learner.json +++ b/src/sagemaker/image_uri_config/linear-learner.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "174872318107" }, diff --git a/src/sagemaker/image_uri_config/model-monitor.json b/src/sagemaker/image_uri_config/model-monitor.json index 117dbecd84..886ae2611f 100644 --- a/src/sagemaker/image_uri_config/model-monitor.json +++ b/src/sagemaker/image_uri_config/model-monitor.json @@ -18,6 +18,7 @@ "cn-north-1": "453000072557", "cn-northwest-1": "453252182341", "eu-central-1": "048819808253", + "eu-central-2": "590183933784", "eu-north-1": "895015795356", "eu-south-1": "933208885752", "eu-south-2": "437450045455", @@ -30,6 +31,8 @@ "sa-east-1": "539772159869", "us-east-1": "156813124566", "us-east-2": "777275614652", + "us-isof-east-1": "853188333426", + "us-isof-south-1": "467912361380", "us-west-1": "890145073186", "us-west-2": "159807026194" }, diff --git a/src/sagemaker/image_uri_config/neo-mxnet.json b/src/sagemaker/image_uri_config/neo-mxnet.json index 730379a81b..ffab6f5b58 100644 --- a/src/sagemaker/image_uri_config/neo-mxnet.json +++ b/src/sagemaker/image_uri_config/neo-mxnet.json @@ -1,6 +1,11 @@ { - "processors": ["cpu", "gpu"], - "scope": ["inference"], + "processors": [ + "cpu", + "gpu" + ], + "scope": [ + "inference" + ], "version_aliases": { "0.12.1": "1.8", "1.0.0": "1.8", @@ -17,7 +22,9 @@ }, "versions": { "1.8": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -31,23 +38,27 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": 
"sagemaker-inference-mxnet" } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/neo-pytorch.json b/src/sagemaker/image_uri_config/neo-pytorch.json index b6a3e633af..39c4f158c8 100644 --- a/src/sagemaker/image_uri_config/neo-pytorch.json +++ b/src/sagemaker/image_uri_config/neo-pytorch.json @@ -1,6 +1,11 @@ { - "processors": ["cpu", "gpu"], - "scope": ["inference"], + "processors": [ + "cpu", + "gpu" + ], + "scope": [ + "inference" + ], "version_aliases": { "0.4.0": "1.4", "1.0.0": "1.4", @@ -21,7 +26,9 @@ }, "versions": { "1.4": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -35,26 +42,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" }, "1.5": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -68,26 +81,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" }, "1.6": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -101,26 +120,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" }, "1.7": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": 
{ "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -134,26 +159,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" }, "1.8": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -167,26 +198,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" }, "1.12": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -200,26 +237,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" }, "1.13": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -233,26 +276,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": 
"406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" }, "2.0": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -266,23 +315,27 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/neo-tensorflow.json b/src/sagemaker/image_uri_config/neo-tensorflow.json index 4fd6c6c58b..2df048167c 100644 --- a/src/sagemaker/image_uri_config/neo-tensorflow.json +++ b/src/sagemaker/image_uri_config/neo-tensorflow.json @@ -1,6 +1,11 @@ { - "processors": ["cpu", "gpu"], - "scope": ["inference"], + "processors": [ + "cpu", + "gpu" + ], + "scope": [ + "inference" + ], "version_aliases": { "1.4.1": "1.15.3", "1.5.0": "1.15.3", @@ -23,7 +28,9 @@ }, "versions": { "1.15.3": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -37,26 +44,32 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-inference-tensorflow" }, "2.9.2": { - "py_versions": ["py3"], + "py_versions": [ + "py3" + ], "registries": { "af-south-1": "774647643957", "ap-east-1": "110948597952", @@ -70,23 +83,27 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": "167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": 
"935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "sagemaker-inference-tensorflow" } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/ntm.json b/src/sagemaker/image_uri_config/ntm.json index d753ccec48..fa945c7304 100644 --- a/src/sagemaker/image_uri_config/ntm.json +++ b/src/sagemaker/image_uri_config/ntm.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "174872318107" }, diff --git a/src/sagemaker/image_uri_config/object-detection.json b/src/sagemaker/image_uri_config/object-detection.json index d036f2ff15..56df938108 100644 --- a/src/sagemaker/image_uri_config/object-detection.json +++ b/src/sagemaker/image_uri_config/object-detection.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "433757028032" }, diff --git a/src/sagemaker/image_uri_config/object2vec.json b/src/sagemaker/image_uri_config/object2vec.json index 53f6686945..d652b352ad 100644 --- a/src/sagemaker/image_uri_config/object2vec.json +++ b/src/sagemaker/image_uri_config/object2vec.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "174872318107" }, diff --git a/src/sagemaker/image_uri_config/pca.json b/src/sagemaker/image_uri_config/pca.json index 64792a8e7b..0c32acda60 100644 --- a/src/sagemaker/image_uri_config/pca.json +++ b/src/sagemaker/image_uri_config/pca.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "174872318107" }, diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json index faf7d6a14a..53c2a75e13 100644 --- a/src/sagemaker/image_uri_config/pytorch-smp.json +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -6,8 +6,11 @@ "version_aliases": { "2.0": "2.0.1", "2.1": "2.1.2", - "2.2": "2.3.0", - "2.2.0": "2.3.0" + "2.2": "2.3.1", + "2.2.0": "2.3.1", + "2.3.1": "2.5.0", + "2.4.1": "2.7.0", + "2.5.1": "2.8.0" }, "versions": { "2.0.1": { @@ -109,6 +112,106 @@ "us-west-2": "658645717510" }, "repository": "smdistributed-modelparallel" + }, + "2.3.1": { + "py_versions": [ + "py310" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" + }, + "2.5.0": { + "py_versions": [ + "py311" + ], + "registries": { + 
"ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" + }, + "2.7.0": { + "py_versions": [ + "py311" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" + }, + "2.8.0": { + "py_versions": [ + "py311" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" } } } diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index f068c68149..8a1993e52a 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -23,17 +23,17 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", - "eu-north-1": "763104351884", "eu-central-2": "380420809688", - "eu-west-1": "763104351884", + "eu-north-1": "763104351884", "eu-south-2": "503227376785", + "eu-west-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference-eia" }, @@ -47,13 +47,13 @@ "ap-south-2": "772153158452", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ca-west-1": "204538143572", "eu-central-2": "380420809688", - "eu-west-1": "763104351884", "eu-south-2": "503227376785", + "eu-west-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference-eia" } @@ -81,7 +81,12 @@ "1.12": "1.12.1", "1.13": "1.13.1", "2.0": "2.0.1", - "2.1": "2.1.0" + "2.1": "2.1.0", + "2.2": "2.2.0", + "2.3": "2.3.0", + "2.4": "2.4.0", + "2.5": "2.5.1", + "2.6": "2.6.0" }, "versions": { "0.4.0": { @@ -193,8 +198,8 @@ ], 
"registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -204,19 +209,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -225,8 +236,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -237,8 +247,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -248,19 +258,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -269,19 +285,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, "1.4.0": { "py_versions": [ - "py3" + "py3", + "py36" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -291,19 +307,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": 
"380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -312,19 +334,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, "1.5.0": { "py_versions": [ - "py3" + "py3", + "py36" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -334,19 +356,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -355,8 +383,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -367,8 +394,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -378,19 +405,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -399,8 +432,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": 
"763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -411,8 +443,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -422,19 +454,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -443,8 +481,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -455,8 +492,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -466,19 +503,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -487,8 +530,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -499,8 +541,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -510,19 +552,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": 
"727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -531,8 +579,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -542,8 +589,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -553,19 +600,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -574,8 +627,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -585,8 +637,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -596,19 +648,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -617,8 +675,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": 
"094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -628,8 +685,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -639,19 +696,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -660,8 +723,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -671,8 +733,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -682,19 +744,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -703,8 +771,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -714,8 +781,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -725,19 +792,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": 
"763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -746,8 +819,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -757,8 +829,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -768,19 +840,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -789,8 +867,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -800,8 +877,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -811,18 +888,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -831,8 +915,7 @@ "us-iso-east-1": "886529160074", 
"us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -842,8 +925,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -853,18 +936,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -873,8 +963,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -884,8 +973,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -895,18 +984,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -915,8 +1011,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -926,8 +1021,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -937,18 +1032,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", 
"cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -957,8 +1059,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -968,8 +1069,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -979,18 +1080,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -998,9 +1106,10 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" }, @@ -1010,8 +1119,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1021,18 +1130,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1040,33 +1156,21 @@ "us-gov-west-1": 
"442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-inference" - } - } - }, - "inference_graviton": { - "processors": [ - "cpu" - ], - "version_aliases": { - "1.12": "1.12.1", - "2.0": "2.0.1", - "2.1": "2.1.0", - "2.2": "2.2.1" - }, - "versions": { - "1.12.1": { + }, + "2.3.0": { "py_versions": [ - "py38" + "py311" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1076,42 +1180,43 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "repository": "pytorch-inference-graviton", - "container_version": { - "cpu": "ubuntu20.04" - } + "repository": "pytorch-inference" }, - "2.0.0": { + "2.4.0": { "py_versions": [ - "py310" + "py311" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1121,36 +1226,43 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "repository": "pytorch-inference-graviton", - 
"container_version": { - "cpu": "ubuntu20.04" - } + "repository": "pytorch-inference" }, - "2.0.1": { + "2.5.1": { "py_versions": [ - "py310" + "py311" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1160,36 +1272,43 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "repository": "pytorch-inference-graviton", - "container_version": { - "cpu": "ubuntu20.04" - } + "repository": "pytorch-inference" }, - "2.1.0": { + "2.6.0": { "py_versions": [ - "py310" + "py312" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1199,36 +1318,210 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "repository": "pytorch-inference-graviton", + "repository": "pytorch-inference" + } + } + }, + "inference_graviton": { + "processors": [ + "cpu" + ], + "version_aliases": { + "1.12": "1.12.1", + "2.0": "2.0.1", + "2.1": "2.1.0", + "2.2": "2.2.1", + "2.3": "2.3.0", + "2.4": "2.4.0" + }, + "versions": { + "1.12.1": { "container_version": { "cpu": "ubuntu20.04" - } - }, - "2.2.1": { + }, + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + 
"ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference-graviton" + }, + "2.0.0": { + "container_version": { + "cpu": "ubuntu20.04" + }, + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference-graviton" + }, + "2.0.1": { + "container_version": { + "cpu": "ubuntu20.04" + }, "py_versions": [ "py310" ], "registries": { "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": 
"380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference-graviton" + }, + "2.1.0": { + "container_version": { + "cpu": "ubuntu20.04" + }, + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1238,31 +1531,181 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference-graviton" + }, + "2.2.1": { + "container_version": { + "cpu": "ubuntu20.04" + }, + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", "eu-south-1": "692866216735", "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "repository": "pytorch-inference-graviton", + "repository": "pytorch-inference-graviton" + }, + "2.3.0": { "container_version": { "cpu": "ubuntu20.04" - } + }, + 
"py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference-graviton" + }, + "2.4.0": { + "container_version": { + "cpu": "ubuntu22.04" + }, + "py_versions": [ + "py311" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference-graviton" } } }, @@ -1289,7 +1732,12 @@ "1.13": "1.13.1", "2.0": "2.0.1", "2.1": "2.1.0", - "2.2": "2.2.0" + "2.2": "2.2.0", + "2.3": "2.3.0", + "2.4": "2.4.0", + "2.5": "2.5.1", + "2.6": "2.6.0", + "2.7": "2.7.1" }, "versions": { "0.4.0": { @@ -1392,17 +1840,263 @@ "us-west-1": "520713654638", "us-west-2": "520713654638" }, - "repository": "sagemaker-pytorch" + "repository": "sagemaker-pytorch" + }, + "1.2.0": { + "py_versions": [ + "py2", + "py3" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + 
"ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" + }, + "1.3.1": { + "py_versions": [ + "py2", + "py3" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" + }, + "1.4.0": { + "py_versions": [ + "py2", + "py3", + "py36" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": 
"763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" + }, + "1.5.0": { + "py_versions": [ + "py3", + "py36" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" + }, + "1.6.0": { + "py_versions": [ + "py3", + "py36" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" }, - "1.2.0": { + "1.7.1": { "py_versions": [ - "py2", - "py3" + "py3", + "py36" ], "registries": { "af-south-1": "626614931356", - "il-central-1": 
"780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1412,19 +2106,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1433,20 +2133,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.3.1": { + "1.8.0": { "py_versions": [ - "py2", - "py3" + "py3", + "py36" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1456,19 +2155,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1477,20 +2182,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.4.0": { + "1.8.1": { "py_versions": [ - "py2", - "py3" + "py3", + "py36" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1500,19 +2204,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": 
"763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1521,19 +2231,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.5.0": { + "1.9.0": { "py_versions": [ - "py3" + "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1543,19 +2252,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1564,20 +2279,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.6.0": { + "1.9.1": { "py_versions": [ - "py3", - "py36" + "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1587,19 +2300,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1608,20 +2327,18 @@ "us-iso-east-1": 
"886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.7.1": { + "1.10.0": { "py_versions": [ - "py3", - "py36" + "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1631,19 +2348,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1652,20 +2375,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.8.0": { + "1.10.2": { "py_versions": [ - "py3", - "py36" + "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1675,19 +2396,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1696,20 +2423,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.8.1": { + "1.11.0": { "py_versions": [ - "py3", - "py36" + "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1719,19 +2444,25 @@ "ap-southeast-2": 
"763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1740,19 +2471,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.9.0": { + "1.12.0": { "py_versions": [ "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1762,19 +2492,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1783,19 +2519,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.9.1": { + "1.12.1": { "py_versions": [ "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1805,19 +2540,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": 
"503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1826,19 +2567,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.10.0": { + "1.13.1": { "py_versions": [ - "py38" + "py39" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1848,19 +2588,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1869,19 +2615,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.10.2": { + "2.0.0": { "py_versions": [ - "py38" + "py310" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1891,19 +2636,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1912,19 +2663,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.11.0": { + "2.0.1": { "py_versions": [ - "py38" + "py310" ], "registries": 
{ "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1934,19 +2684,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1955,19 +2711,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.12.0": { + "2.1.0": { "py_versions": [ - "py38" + "py310" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1977,19 +2732,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1997,20 +2758,21 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.12.1": { + "2.2.0": { "py_versions": [ - "py38" + "py310" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2020,18 +2782,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": 
"763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2039,20 +2808,21 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "1.13.1": { + "2.3.0": { "py_versions": [ - "py39" + "py311" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2062,18 +2832,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2081,20 +2858,21 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "2.0.0": { + "2.4.0": { "py_versions": [ - "py310" + "py311" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2104,39 +2882,43 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": 
"780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "2.0.1": { + "2.5.1": { "py_versions": [ - "py310" + "py311" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2146,39 +2928,43 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "2.1.0": { + "2.6.0": { "py_versions": [ - "py310" + "py312" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2188,39 +2974,43 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" }, - "2.2.0": { + "2.7.1": { "py_versions": [ - "py310" + "py312" ], 
"registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2230,31 +3020,35 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "pytorch-training" } } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/randomcutforest.json b/src/sagemaker/image_uri_config/randomcutforest.json index 74ab6898cc..25d9dcf3e8 100644 --- a/src/sagemaker/image_uri_config/randomcutforest.json +++ b/src/sagemaker/image_uri_config/randomcutforest.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "174872318107" }, diff --git a/src/sagemaker/image_uri_config/sagemaker-base-python.json b/src/sagemaker/image_uri_config/sagemaker-base-python.json index b1eaccf204..cd64d73af1 100644 --- a/src/sagemaker/image_uri_config/sagemaker-base-python.json +++ b/src/sagemaker/image_uri_config/sagemaker-base-python.json @@ -4,6 +4,7 @@ "registries": { "af-south-1": "559312083959", "ap-east-1": "493642496378", + "ap-east-2": "938034419563", "ap-northeast-1": "102112518831", "ap-northeast-2": "806072073708", "ap-northeast-3": "792733760839", @@ -11,10 +12,14 @@ "ap-southeast-1": "492261229750", "ap-southeast-2": "452832661640", "ap-southeast-3": "276181064229", + "ap-southeast-5": "148761635175", + "ap-southeast-7": "528757812139", "ca-central-1": "310906938811", + "ca-west-1": "623308166672", "cn-north-1": "390048526115", "cn-northwest-1": "390780980154", "eu-central-1": "936697816551", + "eu-central-2": "569303640362", "eu-north-1": "243637512696", "eu-south-1": "592751261982", "eu-south-2": "127363102723", @@ -24,11 +29,14 @@ "il-central-1": "380164790875", "me-central-1": "103105715889", "me-south-1": "117516905037", + "mx-central-1": "396913743851", "sa-east-1": "782484402741", "us-east-1": "081325390199", "us-east-2": "429704687514", "us-gov-east-1": "107072934176", "us-gov-west-1": "107173498710", + "us-isof-east-1": "840123138293", + "us-isof-south-1": "883091641454", "us-west-1": "742091327244", "us-west-2": "236514542706" }, diff --git a/src/sagemaker/image_uri_config/sagemaker-distribution.json b/src/sagemaker/image_uri_config/sagemaker-distribution.json new 
file mode 100644 index 0000000000..9853eb01ae --- /dev/null +++ b/src/sagemaker/image_uri_config/sagemaker-distribution.json @@ -0,0 +1,37 @@ +{ + "processors": ["cpu", "gpu"], + "scope": ["inference"], + "version_aliases": { + "3.2": "3.2.0" + }, + "versions": { + "3.2.0": { + "registries": { + "us-east-1": "885854791233", + "us-east-2": "137914896644", + "us-west-1": "053634841547", + "us-west-2": "542918446943", + "af-south-1": "238384257742", + "ap-east-1": "523751269255", + "ap-south-1": "245090515133", + "ap-northeast-2": "064688005998", + "ap-southeast-1": "022667117163", + "ap-southeast-2": "648430277019", + "ap-northeast-1": "010972774902", + "ca-central-1": "481561238223", + "eu-central-1": "545423591354", + "eu-west-1": "819792524951", + "eu-west-2": "021081402939", + "eu-west-3": "856416204555", + "eu-north-1": "175620155138", + "eu-south-1": "810671768855", + "sa-east-1": "567556641782", + "ap-northeast-3": "564864627153", + "ap-southeast-3": "370607712162", + "me-south-1": "523774347010", + "me-central-1": "358593528301" + }, + "repository": "sagemaker-distribution-prod" + } + } +} diff --git a/src/sagemaker/image_uri_config/sagemaker-tritonserver.json b/src/sagemaker/image_uri_config/sagemaker-tritonserver.json index b2257ce803..91842ae713 100644 --- a/src/sagemaker/image_uri_config/sagemaker-tritonserver.json +++ b/src/sagemaker/image_uri_config/sagemaker-tritonserver.json @@ -1,107 +1,212 @@ { - "processors": [ - "cpu", - "gpu" - ], - "scope": [ - "inference" - ], - "versions": { - "24.03": { - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "sagemaker-tritonserver", - "tag_prefix": "24.03-py3" - }, - "24.01": { - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "sagemaker-tritonserver", - "tag_prefix": "24.01-py3" - }, - "23.12": { - "registries": { - "af-south-1": "626614931356", - "il-central-1": "780543022126", - 
"ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" - }, - "repository": "sagemaker-tritonserver", - "tag_prefix": "23.12-py3" - } - } -} \ No newline at end of file + "processors": [ + "cpu", + "gpu" + ], + "scope": [ + "inference" + ], + "versions": { + "25.04": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "25.04-py3" + }, + "24.09": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "24.09-py3" + }, + "24.05": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": 
"763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "24.05-py3" + }, + "24.03": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "24.03-py3" + }, + "24.01": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "24.01-py3" + }, + "23.12": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "23.12-py3" + } + } +} diff 
--git a/src/sagemaker/image_uri_config/semantic-segmentation.json b/src/sagemaker/image_uri_config/semantic-segmentation.json index e6e2b4350b..83f3e35f11 100644 --- a/src/sagemaker/image_uri_config/semantic-segmentation.json +++ b/src/sagemaker/image_uri_config/semantic-segmentation.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "433757028032" }, diff --git a/src/sagemaker/image_uri_config/seq2seq.json b/src/sagemaker/image_uri_config/seq2seq.json index 143f966a99..673b525468 100644 --- a/src/sagemaker/image_uri_config/seq2seq.json +++ b/src/sagemaker/image_uri_config/seq2seq.json @@ -39,6 +39,8 @@ "us-gov-west-1": "226302683700", "us-iso-east-1": "490574956308", "us-isob-east-1": "765400339828", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "632365934929", "us-west-2": "433757028032" }, diff --git a/src/sagemaker/image_uri_config/sklearn.json b/src/sagemaker/image_uri_config/sklearn.json index 656758d607..85114a11d2 100644 --- a/src/sagemaker/image_uri_config/sklearn.json +++ b/src/sagemaker/image_uri_config/sklearn.json @@ -42,6 +42,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -88,6 +90,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -134,6 +138,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -180,6 +186,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -230,6 +238,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -276,6 +286,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -322,6 +334,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -368,6 +382,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -418,6 +434,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, diff --git 
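[Editor's note, not part of the diff: the semantic-segmentation, seq2seq, and sklearn changes above are the same two-line addition repeated once per version block, registering account IDs for the new us-isof-east-1 and us-isof-south-1 regions. Once merged, those regions resolve through the SDK's normal lookup. A hedged usage sketch follows; the version, py_version, and instance type are illustrative and should be checked against the versions actually listed in sklearn.json.]

from sagemaker import image_uris

# Illustrative only: the framework key matches the JSON config file name,
# and region selects the account ID from the "registries" map added above.
uri = image_uris.retrieve(
    framework="sklearn",
    region="us-isof-east-1",       # newly registered: account 108575199400
    version="1.2-1",               # one of the versions defined in sklearn.json
    py_version="py3",
    instance_type="ml.m5.xlarge",  # sklearn images are CPU-only
    image_scope="inference",
)
print(uri)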
a/src/sagemaker/image_uri_config/spark.json b/src/sagemaker/image_uri_config/spark.json index 9a33ca87d9..0a430ebc77 100644 --- a/src/sagemaker/image_uri_config/spark.json +++ b/src/sagemaker/image_uri_config/spark.json @@ -11,6 +11,7 @@ "registries": { "af-south-1": "309385258863", "ap-east-1": "732049463269", + "ap-east-2": "533267296287", "ap-northeast-1": "411782140378", "ap-northeast-2": "860869212795", "ap-northeast-3": "102471314380", @@ -20,6 +21,8 @@ "ap-southeast-2": "440695851116", "ap-southeast-3": "800295151634", "ap-southeast-4": "819679513684", + "ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", "ca-central-1": "446299261295", "ca-west-1": "000907499111", "cn-north-1": "671472414489", @@ -35,6 +38,7 @@ "il-central-1": "408426139102", "me-central-1": "395420993607", "me-south-1": "750251592176", + "mx-central-1": "211125459255", "sa-east-1": "737130764395", "us-east-1": "173754725891", "us-east-2": "314815235551", @@ -52,6 +56,7 @@ "registries": { "af-south-1": "309385258863", "ap-east-1": "732049463269", + "ap-east-2": "533267296287", "ap-northeast-1": "411782140378", "ap-northeast-2": "860869212795", "ap-northeast-3": "102471314380", @@ -61,6 +66,8 @@ "ap-southeast-2": "440695851116", "ap-southeast-3": "800295151634", "ap-southeast-4": "819679513684", + "ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", "ca-central-1": "446299261295", "ca-west-1": "000907499111", "cn-north-1": "671472414489", @@ -76,6 +83,7 @@ "il-central-1": "408426139102", "me-central-1": "395420993607", "me-south-1": "750251592176", + "mx-central-1": "211125459255", "sa-east-1": "737130764395", "us-east-1": "173754725891", "us-east-2": "314815235551", @@ -93,6 +101,7 @@ "registries": { "af-south-1": "309385258863", "ap-east-1": "732049463269", + "ap-east-2": "533267296287", "ap-northeast-1": "411782140378", "ap-northeast-2": "860869212795", "ap-northeast-3": "102471314380", @@ -102,6 +111,8 @@ "ap-southeast-2": "440695851116", "ap-southeast-3": "800295151634", "ap-southeast-4": "819679513684", + "ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", "ca-central-1": "446299261295", "ca-west-1": "000907499111", "cn-north-1": "671472414489", @@ -117,6 +128,7 @@ "il-central-1": "408426139102", "me-central-1": "395420993607", "me-south-1": "750251592176", + "mx-central-1": "211125459255", "sa-east-1": "737130764395", "us-east-1": "173754725891", "us-east-2": "314815235551", @@ -134,6 +146,7 @@ "registries": { "af-south-1": "309385258863", "ap-east-1": "732049463269", + "ap-east-2": "533267296287", "ap-northeast-1": "411782140378", "ap-northeast-2": "860869212795", "ap-northeast-3": "102471314380", @@ -143,6 +156,8 @@ "ap-southeast-2": "440695851116", "ap-southeast-3": "800295151634", "ap-southeast-4": "819679513684", + "ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", "ca-central-1": "446299261295", "ca-west-1": "000907499111", "cn-north-1": "671472414489", @@ -158,6 +173,7 @@ "il-central-1": "408426139102", "me-central-1": "395420993607", "me-south-1": "750251592176", + "mx-central-1": "211125459255", "sa-east-1": "737130764395", "us-east-1": "173754725891", "us-east-2": "314815235551", @@ -175,6 +191,7 @@ "registries": { "af-south-1": "309385258863", "ap-east-1": "732049463269", + "ap-east-2": "533267296287", "ap-northeast-1": "411782140378", "ap-northeast-2": "860869212795", "ap-northeast-3": "102471314380", @@ -184,6 +201,8 @@ "ap-southeast-2": "440695851116", "ap-southeast-3": "800295151634", "ap-southeast-4": "819679513684", + 
"ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", "ca-central-1": "446299261295", "ca-west-1": "000907499111", "cn-north-1": "671472414489", @@ -199,6 +218,53 @@ "il-central-1": "408426139102", "me-central-1": "395420993607", "me-south-1": "750251592176", + "mx-central-1": "211125459255", + "sa-east-1": "737130764395", + "us-east-1": "173754725891", + "us-east-2": "314815235551", + "us-gov-east-1": "260923028637", + "us-gov-west-1": "271483468897", + "us-west-1": "667973535471", + "us-west-2": "153931337802" + }, + "repository": "sagemaker-spark-processing" + }, + "3.5": { + "py_versions": [ + "py39", + "py312" + ], + "registries": { + "af-south-1": "309385258863", + "ap-east-1": "732049463269", + "ap-east-2": "533267296287", + "ap-northeast-1": "411782140378", + "ap-northeast-2": "860869212795", + "ap-northeast-3": "102471314380", + "ap-south-1": "105495057255", + "ap-south-2": "873151114052", + "ap-southeast-1": "759080221371", + "ap-southeast-2": "440695851116", + "ap-southeast-3": "800295151634", + "ap-southeast-4": "819679513684", + "ap-southeast-5": "841784149062", + "ap-southeast-7": "471112967968", + "ca-central-1": "446299261295", + "ca-west-1": "000907499111", + "cn-north-1": "671472414489", + "cn-northwest-1": "844356804704", + "eu-central-1": "906073651304", + "eu-central-2": "142351485170", + "eu-north-1": "330188676905", + "eu-south-1": "753923664805", + "eu-south-2": "833944533722", + "eu-west-1": "571004829621", + "eu-west-2": "836651553127", + "eu-west-3": "136845547031", + "il-central-1": "408426139102", + "me-central-1": "395420993607", + "me-south-1": "750251592176", + "mx-central-1": "211125459255", "sa-east-1": "737130764395", "us-east-1": "173754725891", "us-east-2": "314815235551", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 5dc8d35af2..f793edb4c9 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -140,7 +140,6 @@ "1.14.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", @@ -152,6 +151,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -162,8 +162,9 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -172,15 +173,13 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference-eia" }, "1.15.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", @@ -192,6 +191,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -202,8 +202,9 @@ "eu-west-1": "763104351884", "eu-west-2": 
"763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -212,15 +213,13 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference-eia" }, "2.0.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", @@ -232,6 +231,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -242,8 +242,9 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -252,15 +253,13 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference-eia" }, "2.3.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", @@ -272,6 +271,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -282,8 +282,9 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -292,8 +293,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference-eia" } @@ -305,18 +305,18 @@ "gpu" ], "version_aliases": { - "1.10": "1.10.0", - "1.11": "1.11.0", - "1.12": "1.12.0", - "1.13": "1.13.0", - "1.14": "1.14.0", - "1.15": "1.15.5", "1.4": "1.4.1", "1.5": "1.5.0", "1.6": "1.6.0", "1.7": "1.7.0", "1.8": "1.8.0", "1.9": "1.9.0", + "1.10": "1.10.0", + "1.11": "1.11.0", + "1.12": "1.12.0", + "1.13": "1.13.0", + "1.14": "1.14.0", + "1.15": "1.15.5", "2.0": "2.0.4", "2.1": "2.1.3", "2.2": "2.2.2", @@ -331,9 +331,210 @@ "2.11": "2.11.1", "2.12": "2.12.1", "2.13": "2.13.0", - "2.14": "2.14.1" + "2.14": "2.14.1", + "2.16": "2.16.1", + "2.18": "2.18.0", + "2.19": "2.19.0" }, "versions": { + "1.4.1": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + 
"cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, + "1.5.0": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, + "1.6.0": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, + "1.7.0": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, + "1.8.0": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + 
"ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, + "1.9.0": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, "1.10.0": { "py_versions": [ "py2" @@ -430,8 +631,8 @@ "1.13.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -441,7 +642,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -452,8 +657,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -462,16 +669,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "1.14.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -481,7 +687,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": 
"204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -492,8 +702,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -502,16 +714,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "1.15.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -521,7 +732,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -532,8 +747,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -542,16 +759,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "1.15.2": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -561,7 +777,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -572,8 +792,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -582,16 +804,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "1.15.3": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -601,7 +822,11 @@ "ap-southeast-2": "763104351884", 
"ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -612,8 +837,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -622,16 +849,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "1.15.4": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -641,7 +867,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -652,8 +882,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -662,16 +894,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "1.15.5": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -681,7 +912,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -692,8 +927,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -702,214 +939,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, - "1.4.1": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - 
"ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, - "1.5.0": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, - "1.6.0": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, - "1.7.0": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": 
"520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, - "1.8.0": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, - "1.9.0": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, "2.0.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -919,7 +957,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -930,8 +972,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -940,16 +984,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.0.1": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -959,7 +1002,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": 
"907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -970,8 +1017,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -980,16 +1029,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.0.2": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -999,7 +1047,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1010,8 +1062,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1020,16 +1074,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.0.3": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1039,7 +1092,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1050,8 +1107,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1060,16 +1119,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.0.4": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": 
"871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1079,7 +1137,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1090,8 +1152,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1100,16 +1164,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.1.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1119,7 +1182,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1130,8 +1197,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1140,16 +1209,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.1.1": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1159,7 +1227,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1170,8 +1242,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1180,16 +1254,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - 
"us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.1.2": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1199,7 +1272,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1210,8 +1287,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1220,16 +1299,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.1.3": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1239,7 +1317,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1250,8 +1332,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1260,16 +1344,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.2.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1279,7 +1362,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1290,8 +1377,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": 
"637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1300,16 +1389,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.2.1": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1319,7 +1407,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1330,8 +1422,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1340,16 +1434,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.2.2": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1359,7 +1452,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1370,8 +1467,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1380,16 +1479,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.3.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1399,7 +1497,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1410,8 +1512,10 @@ 
"eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1420,16 +1524,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.3.1": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1439,7 +1542,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1450,8 +1557,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1460,16 +1569,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.3.2": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1479,7 +1587,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1490,8 +1602,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1500,16 +1614,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.4.1": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1519,7 +1632,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": 
"633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1530,8 +1647,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1540,16 +1659,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.4.3": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1559,7 +1677,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1570,8 +1692,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1580,16 +1704,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.5.1": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1599,7 +1722,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1610,8 +1737,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1620,16 +1749,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.6.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": 
"763104351884", "ap-northeast-3": "364406365360", @@ -1639,7 +1767,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1650,8 +1782,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1660,16 +1794,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.6.3": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1679,7 +1812,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1690,8 +1827,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1700,16 +1839,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.7.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1719,7 +1857,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1730,8 +1872,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1740,16 +1884,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": 
"tensorflow-inference" }, "2.8.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1759,7 +1902,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1770,8 +1917,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1780,16 +1929,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.8.4": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1799,7 +1947,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1810,8 +1962,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1820,16 +1974,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.9.2": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1839,7 +1992,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1850,7 +2007,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1859,16 +2019,15 @@ 
"us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.9.3": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1878,7 +2037,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1889,7 +2052,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1898,16 +2064,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.10.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1917,7 +2082,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1928,7 +2097,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1937,16 +2109,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.10.1": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1956,7 +2127,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -1967,7 +2142,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + 
"mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1976,16 +2154,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.11.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -1995,7 +2172,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2006,7 +2187,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2015,16 +2199,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.11.1": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2034,7 +2217,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2045,7 +2232,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2054,16 +2244,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.12.1": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2073,7 +2262,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2084,7 +2277,10 @@ "eu-west-1": "763104351884", "eu-west-2": 
"763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2093,16 +2289,15 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, "2.13.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2112,7 +2307,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2123,17 +2322,21 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-inference" }, @@ -2141,6 +2344,7 @@ "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2150,6 +2354,9 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -2165,6 +2372,7 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2172,32 +2380,18 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", "us-west-2": "763104351884" }, "repository": "tensorflow-inference" - } - } - }, - "inference_graviton": { - "processors": [ - "cpu" - ], - "version_aliases": { - "2.9": "2.9.1", - "2.12": "2.12.1", - "2.13": "2.13.0", - "2.14": "2.14.1" - }, - "versions": { - "2.9.1": { - "py_versions": [ - "py38" - ], + }, + "2.16.1": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2207,42 +2401,40 @@ 
"ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "repository": "tensorflow-inference-graviton", - "container_version": { - "cpu": "ubuntu20.04" - } + "repository": "tensorflow-inference" }, - "2.12.1": { - "py_versions": [ - "py310" - ], + "2.18.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2252,42 +2444,40 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "repository": "tensorflow-inference-graviton", - "container_version": { - "cpu": "ubuntu20.04" - } + "repository": "tensorflow-inference" }, - "2.13.0": { - "py_versions": [ - "py310" - ], + "2.19.0": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2297,42 +2487,60 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": 
"763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "repository": "tensorflow-inference-graviton", + "repository": "tensorflow-inference" + } + } + }, + "inference_graviton": { + "processors": [ + "cpu" + ], + "version_aliases": { + "2.9": "2.9.1", + "2.12": "2.12.1", + "2.13": "2.13.0", + "2.14": "2.14.1", + "2.16": "2.16.1" + }, + "versions": { + "2.9.1": { "container_version": { "cpu": "ubuntu20.04" - } - }, - "2.14.1": { + }, "py_versions": [ - "py310" + "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2342,18 +2550,25 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-central-2": "380420809688", "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2362,34 +2577,232 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, - "repository": "tensorflow-inference-graviton", + "repository": "tensorflow-inference-graviton" + }, + "2.12.1": { "container_version": { "cpu": "ubuntu20.04" - } - } - } - }, - "training": { - "processors": [ - "cpu", - "gpu" - ], - "version_aliases": { - "1.10": "1.10.0", - "1.11": "1.11.0", - "1.12": "1.12.0", - "1.13": "1.13.1", - "1.14": "1.14.0", - "1.15": "1.15.5", + }, + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": 
"204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference-graviton" + }, + "2.13.0": { + "container_version": { + "cpu": "ubuntu20.04" + }, + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference-graviton" + }, + "2.14.1": { + "container_version": { + "cpu": "ubuntu20.04" + }, + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": 
"763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference-graviton" + }, + "2.16.1": { + "container_version": { + "cpu": "ubuntu20.04" + }, + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference-graviton" + } + } + }, + "training": { + "processors": [ + "cpu", + "gpu" + ], + "version_aliases": { "1.4": "1.4.1", "1.5": "1.5.0", "1.6": "1.6.0", "1.7": "1.7.0", "1.8": "1.8.0", "1.9": "1.9.0", + "1.10": "1.10.0", + "1.11": "1.11.0", + "1.12": "1.12.0", + "1.13": "1.13.1", + "1.14": "1.14.0", + "1.15": "1.15.5", "2.0": "2.0.4", "2.1": "2.1.3", "2.2": "2.2.2", @@ -2404,9 +2817,210 @@ "2.11": "2.11.0", "2.12": "2.12.0", "2.13": "2.13.0", - "2.14": "2.14.1" + "2.14": "2.14.1", + "2.16": "2.16.2", + "2.18": "2.18.0", + "2.19": "2.19.0" }, "versions": { + "1.4.1": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, + "1.5.0": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": 
"520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, + "1.6.0": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, + "1.7.0": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, + "1.8.0": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, + "1.9.0": { + "py_versions": [ + "py2" + ], + "registries": { + "af-south-1": "313743910680", + "ap-east-1": "057415533634", + 
"ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-south-1": "048378556238", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-isob-east-1": "453391408702", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-tensorflow" + }, "1.10.0": { "py_versions": [ "py2" @@ -2542,7 +3156,6 @@ "py3": { "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", @@ -2554,6 +3167,7 @@ "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2564,8 +3178,9 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2574,8 +3189,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" } @@ -2587,8 +3201,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2598,7 +3212,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2609,8 +3227,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2619,8 +3239,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, @@ -2631,8 +3250,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2642,7 +3261,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + 
"ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2653,8 +3276,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2663,8 +3288,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, @@ -2676,8 +3300,8 @@ ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2687,7 +3311,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2698,8 +3326,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2708,8 +3338,7 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, @@ -2721,8 +3350,157 @@ ], "registries": { "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": 
"tensorflow-training" + }, + "1.15.4": { + "py_versions": [ + "py3", + "py36", + "py37" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-training" + }, + "1.15.5": { + "py_versions": [ + "py3", + "py36", + "py37" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-east-2": "975050140332", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", + "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", "il-central-1": "780543022126", + "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-training" + }, + "2.0.0": { + "py_versions": [ + "py2", + "py3" + ], + "registries": { + "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2732,7 +3510,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", 
"cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2743,8 +3525,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2753,21 +3537,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "1.15.4": { + "2.0.1": { "py_versions": [ - "py3", - "py36", - "py37" + "py2", + "py3" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2777,7 +3559,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2788,8 +3574,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2798,21 +3586,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "1.15.5": { + "2.0.2": { "py_versions": [ - "py3", - "py36", - "py37" + "py2", + "py3" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -2822,7 +3608,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -2833,8 +3623,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2843,218 +3635,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "1.4.1": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": 
"520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, - "1.5.0": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, - "1.6.0": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, - "1.7.0": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": 
"sagemaker-tensorflow" - }, - "1.8.0": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, - "1.9.0": { - "py_versions": [ - "py2" - ], - "registries": { - "af-south-1": "313743910680", - "ap-east-1": "057415533634", - "ap-northeast-1": "520713654638", - "ap-northeast-2": "520713654638", - "ap-south-1": "520713654638", - "ap-southeast-1": "520713654638", - "ap-southeast-2": "520713654638", - "ca-central-1": "520713654638", - "cn-north-1": "422961961927", - "cn-northwest-1": "423003514399", - "eu-central-1": "520713654638", - "eu-north-1": "520713654638", - "eu-south-1": "048378556238", - "eu-west-1": "520713654638", - "eu-west-2": "520713654638", - "eu-west-3": "520713654638", - "me-south-1": "724002660598", - "sa-east-1": "520713654638", - "us-east-1": "520713654638", - "us-east-2": "520713654638", - "us-gov-west-1": "246785580436", - "us-iso-east-1": "744548109606", - "us-isob-east-1": "453391408702", - "us-west-1": "520713654638", - "us-west-2": "520713654638" - }, - "repository": "sagemaker-tensorflow" - }, - "2.0.0": { + "2.0.3": { "py_versions": [ - "py2", - "py3" + "py3", + "py36" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3064,7 +3657,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3075,8 +3672,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3085,20 +3684,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.0.1": { + "2.0.4": { "py_versions": [ - "py2", - "py3" + "py3", + "py36" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ 
-3108,7 +3706,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3119,8 +3721,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3129,20 +3733,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.0.2": { + "2.1.0": { "py_versions": [ "py2", "py3" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3152,7 +3755,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3163,8 +3770,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3173,19 +3782,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.0.3": { + "2.1.1": { "py_versions": [ + "py2", "py3" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3195,7 +3804,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3206,8 +3819,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3216,19 +3831,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + 
"us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.0.4": { + "2.1.2": { "py_versions": [ - "py3" + "py3", + "py36" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3238,7 +3853,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3249,8 +3868,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3259,20 +3880,19 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.1.0": { + "2.1.3": { "py_versions": [ - "py2", - "py3" + "py3", + "py36" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3282,7 +3902,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3293,8 +3917,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3303,20 +3929,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.1.1": { + "2.2.0": { "py_versions": [ - "py2", - "py3" + "py37" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3326,7 +3950,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3337,8 +3965,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": 
"217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3347,19 +3977,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.1.2": { + "2.2.1": { "py_versions": [ - "py3" + "py37" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3369,7 +3998,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3380,8 +4013,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3390,19 +4025,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.1.3": { + "2.2.2": { "py_versions": [ - "py3" + "py37" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3412,7 +4046,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3423,8 +4061,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3433,19 +4073,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.2.0": { + "2.3.0": { "py_versions": [ "py37" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3455,7 +4094,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": 
"550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3466,8 +4109,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3476,19 +4121,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.2.1": { + "2.3.1": { "py_versions": [ "py37" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3498,7 +4142,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3509,8 +4157,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3519,19 +4169,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.2.2": { + "2.3.2": { "py_versions": [ "py37" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3541,7 +4190,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3552,8 +4205,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3562,19 +4217,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.3.0": { + "2.4.1": { "py_versions": [ "py37" ], "registries": { "af-south-1": 
"626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3584,7 +4238,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3595,8 +4253,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3605,19 +4265,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.3.1": { + "2.4.3": { "py_versions": [ "py37" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3627,7 +4286,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3638,8 +4301,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3648,19 +4313,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.3.2": { + "2.5.0": { "py_versions": [ "py37" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3670,7 +4334,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3681,8 +4349,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": 
"763104351884", @@ -3691,19 +4361,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.4.1": { + "2.5.1": { "py_versions": [ "py37" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3713,7 +4382,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3724,8 +4397,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3734,19 +4409,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.4.3": { + "2.6.0": { "py_versions": [ - "py37" + "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3756,7 +4430,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3767,8 +4445,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3777,19 +4457,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.5.0": { + "2.6.2": { "py_versions": [ - "py37" + "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3799,7 +4478,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", 
"eu-central-1": "763104351884", @@ -3810,8 +4493,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3820,19 +4505,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.5.1": { + "2.6.3": { "py_versions": [ - "py37" + "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3842,7 +4526,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3853,8 +4541,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3863,19 +4553,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.6.0": { + "2.7.1": { "py_versions": [ "py38" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3885,7 +4574,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3896,8 +4589,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3906,19 +4601,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.6.2": { + "2.8.0": { "py_versions": [ - "py38" + "py39" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": 
"364406365360", @@ -3928,7 +4622,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3939,8 +4637,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3949,19 +4649,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.6.3": { + "2.9.2": { "py_versions": [ - "py38" + "py39" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -3971,7 +4670,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -3982,8 +4685,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3992,19 +4697,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.7.1": { + "2.10.1": { "py_versions": [ - "py38" + "py39" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4014,7 +4718,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -4025,8 +4733,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4035,19 +4745,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - 
"ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.8.0": { + "2.11.0": { "py_versions": [ "py39" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4057,7 +4766,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -4068,8 +4781,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4078,19 +4793,18 @@ "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.9.2": { + "2.12.0": { "py_versions": [ - "py39" + "py310" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4100,7 +4814,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -4111,29 +4829,28 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", - "me-south-1": "217643126080", + "il-central-1": "780543022126", "me-central-1": "914824155844", + "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.10.1": { + "2.13.0": { "py_versions": [ - "py39" + "py310" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4143,7 +4860,11 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", @@ -4154,7 +4875,10 @@ "eu-west-1": "763104351884", "eu-west-2": "763104351884", 
"eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4162,38 +4886,49 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.11.0": { + "2.14.1": { "py_versions": [ - "py39" + "py310" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -4201,91 +4936,113 @@ "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-isob-east-1": "094389454867", + "us-isof-east-1": "303241398832", + "us-isof-south-1": "454834333376", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.12.0": { + "2.16.2": { "py_versions": [ "py310" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": 
"763104351884" }, "repository": "tensorflow-training" }, - "2.13.0": { + "2.18.0": { "py_versions": [ "py310" ], "registries": { "af-south-1": "626614931356", - "il-central-1": "780543022126", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", + "ca-west-1": "204538143572", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", + "il-central-1": "780543022126", + "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", - "us-west-2": "763104351884", - "ca-west-1": "204538143572" + "us-west-2": "763104351884" }, "repository": "tensorflow-training" }, - "2.14.1": { + "2.19.0": { "py_versions": [ - "py310" + "py312" ], "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", + "ap-east-2": "975050140332", "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", @@ -4295,6 +5052,9 @@ "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", "ap-southeast-4": "457447274322", + "ap-southeast-5": "550225433462", + "ap-southeast-6": "633930458069", + "ap-southeast-7": "590183813437", "ca-central-1": "763104351884", "ca-west-1": "204538143572", "cn-north-1": "727897471807", @@ -4310,13 +5070,12 @@ "il-central-1": "780543022126", "me-central-1": "914824155844", "me-south-1": "217643126080", + "mx-central-1": "637423239942", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-isob-east-1": "094389454867", "us-west-1": "763104351884", "us-west-2": "763104351884" }, @@ -4324,4 +5083,4 @@ } } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/xgboost-neo.json b/src/sagemaker/image_uri_config/xgboost-neo.json index 96ac466bdf..f727cb43d7 100644 --- a/src/sagemaker/image_uri_config/xgboost-neo.json +++ b/src/sagemaker/image_uri_config/xgboost-neo.json @@ -1,5 +1,7 @@ { - "scope": ["inference"], + "scope": [ + "inference" + ], "versions": { "latest": { "registries": { @@ -15,23 +17,27 @@ "cn-north-1": "472730292857", "cn-northwest-1": "474822919863", "eu-central-1": "746233611703", + "eu-central-2": "010526262399", "eu-north-1": "601324751636", "eu-south-1": "966458181534", "eu-west-1": "802834080501", "eu-west-2": "205493899709", "eu-west-3": "254080097072", + "il-central-1": "275950707576", "me-south-1": "836785723513", "sa-east-1": "756306329178", "us-east-1": "785573368785", "us-east-2": "007439368137", + "us-gov-east-1": "227234621604", "us-gov-west-1": "263933020539", "us-iso-east-1": 
"167761179201", "us-isob-east-1": "406031935815", + "us-isof-east-1": "751086301963", + "us-isof-south-1": "935523707064", "us-west-1": "710691900526", - "us-west-2": "301217895009", - "il-central-1": "275950707576" + "us-west-2": "301217895009" }, "repository": "xgboost-neo" } } -} +} \ No newline at end of file diff --git a/src/sagemaker/image_uri_config/xgboost.json b/src/sagemaker/image_uri_config/xgboost.json index 573a2db10e..88d621af49 100644 --- a/src/sagemaker/image_uri_config/xgboost.json +++ b/src/sagemaker/image_uri_config/xgboost.json @@ -83,6 +83,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -129,6 +131,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -175,6 +179,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -215,6 +221,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -255,6 +263,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -295,6 +305,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -335,6 +347,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -375,6 +389,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -466,6 +482,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -512,6 +530,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -558,6 +578,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -598,6 +620,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": 
"746614075791", "us-west-2": "246618743249" }, @@ -638,6 +662,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -678,6 +704,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -718,6 +746,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -758,6 +788,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -802,6 +834,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, @@ -842,6 +876,8 @@ "us-gov-west-1": "414596584902", "us-iso-east-1": "833128469047", "us-isob-east-1": "281123927165", + "us-isof-east-1": "108575199400", + "us-isof-south-1": "124985052026", "us-west-1": "746614075791", "us-west-2": "246618743249" }, diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 143ecc9bdb..de6d622f78 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -21,7 +21,8 @@ from packaging.version import Version from sagemaker import utils -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import is_jumpstart_model_input from sagemaker.spark import defaults from sagemaker.jumpstart import artifacts @@ -37,6 +38,8 @@ ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}" HUGGING_FACE_FRAMEWORK = "huggingface" HUGGING_FACE_LLM_FRAMEWORK = "huggingface-llm" +HUGGING_FACE_TEI_GPU_FRAMEWORK = "huggingface-tei" +HUGGING_FACE_TEI_CPU_FRAMEWORK = "huggingface-tei-cpu" HUGGING_FACE_LLM_NEURONX_FRAMEWORK = "huggingface-llm-neuronx" XGBOOST_FRAMEWORK = "xgboost" SKLEARN_FRAMEWORK = "sklearn" @@ -62,12 +65,15 @@ def retrieve( training_compiler_config=None, model_id=None, model_version=None, + hub_arn=None, tolerate_vulnerable_model=False, tolerate_deprecated_model=False, sdk_version=None, inference_tool=None, serverless_inference_config=None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name=None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -95,6 +101,8 @@ def retrieve( https://github.com/aws/deep-learning-containers/blob/master/available_images.md (default: None). distribution (dict): A dictionary with information on how to run distributed training + base_framework_version (str): The base version number of PyTorch or Tensorflow. + (default: None). 
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): A configuration class for the SageMaker Training Compiler (default: None). @@ -102,6 +110,8 @@ (default: None). model_version (str): The version of the JumpStart model for which to retrieve the image URI (default: None). + hub_arn (str): The ARN of the SageMaker Hub from which to retrieve + model details. (Default: None). tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications should be tolerated without an exception raised. If ``False``, raises an exception if the script used by this version of the model has dependencies with known security @@ -121,6 +131,9 @@ object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model; can be an open weights model + or a proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: The ECR URI for the corresponding SageMaker Docker image. @@ -143,23 +156,30 @@ ) if is_jumpstart_model_input(model_id, model_version): + if non_none_fields := { + key: value + for key, value in args.items() + if key in {"version", "framework", "container_version", "py_version"} + and value is not None + }: + JUMPSTART_LOGGER.info( + "Ignoring the following arguments when retrieving image uri " + "for JumpStart model id '%s': %s", + model_id, + str(non_none_fields), + ) return artifacts._retrieve_image_uri( - model_id, - model_version, - image_scope, - framework, - region, - version, - py_version, - instance_type, - accelerator_type, - container_version, - distribution, - base_framework_version, - training_compiler_config, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + image_scope=image_scope, + hub_arn=hub_arn, + region=region, + instance_type=instance_type, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): @@ -178,7 +198,7 @@ config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type) original_version = version - version = _validate_version_and_set_if_needed(version, config, framework) + version = _validate_version_and_set_if_needed(version, config, framework, image_scope) version_config = config["versions"][_version_for_config(version, config)] if framework == HUGGING_FACE_FRAMEWORK: @@ -210,7 +230,7 @@ container_version = version_config["container_version"][processor] # Append sdk version in case of trainium instances - if repo in ["pytorch-training-neuron"]: + if repo in ["pytorch-training-neuron", "pytorch-training-neuronx"]: if not sdk_version: sdk_version = _get_latest_versions(version_config["sdk_versions"]) container_version = sdk_version + "-" + container_version @@ -449,6 +469,23 @@ return sorted(list_of_versions, reverse=True)[0] +def _get_latest_version(framework, version, image_scope): + """Get the latest version for the given framework.""" + if version: + return version + try: + framework_config =
config_for_framework(framework) + except FileNotFoundError: + raise ValueError("Invalid framework {}".format(framework)) + + if not framework_config: + raise ValueError("Invalid framework {}".format(framework)) + + if not version: + version = _fetch_latest_version_from_config(framework_config, image_scope) + return version + + def _validate_accelerator_type(accelerator_type): """Raises a ``ValueError`` if ``accelerator_type`` is invalid.""" if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook": @@ -458,30 +495,16 @@ def _validate_accelerator_type(accelerator_type): ) -def _validate_version_and_set_if_needed(version, config, framework): +def _validate_version_and_set_if_needed(version, config, framework, image_scope): """Checks if the framework/algorithm version is one of the supported versions.""" + if not config: + config = config_for_framework(framework) available_versions = list(config["versions"].keys()) aliased_versions = list(config.get("version_aliases", {}).keys()) - if len(available_versions) == 1 and version not in aliased_versions: - log_message = "Defaulting to the only supported framework/algorithm version: {}.".format( - available_versions[0] - ) - if version and version != available_versions[0]: - logger.warning("%s Ignoring framework/algorithm version: %s.", log_message, version) - elif not version: - logger.info(log_message) - return available_versions[0] - - if version is None and framework in [ - DATA_WRANGLER_FRAMEWORK, - HUGGING_FACE_LLM_FRAMEWORK, - HUGGING_FACE_LLM_NEURONX_FRAMEWORK, - STABILITYAI_FRAMEWORK, - ]: - version = _get_latest_versions(available_versions) - + if not version: + version = _get_latest_version(framework, version, image_scope) _validate_arg(version, available_versions + aliased_versions, "{} version".format(framework)) return version @@ -678,10 +701,16 @@ def get_training_image_uri( if "modelparallel" in distribution["smdistributed"]: if distribution["smdistributed"]["modelparallel"].get("enabled", True): framework = "pytorch-smp" - if ( - "p5" in instance_type - or "2.1" in framework_version - or "2.2" in framework_version + supported_smp_pt_versions_cu124 = ("2.5",) + supported_smp_pt_versions_cu121 = ("2.1", "2.2", "2.3", "2.4") + if any( + pt_version in framework_version + for pt_version in supported_smp_pt_versions_cu124 + ): + container_version = "cu124" + elif "p5" in instance_type or any( + pt_version in framework_version + for pt_version in supported_smp_pt_versions_cu121 ): container_version = "cu121" else: @@ -728,3 +757,55 @@ def get_base_python_image_uri(region, py_version="310") -> str: repo_and_tag = repo + ":" + version return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo_and_tag) + + +def _fetch_latest_version_from_config( # pylint: disable=R0911 + framework_config: dict, image_scope: Optional[str] = None +) -> Optional[str]: + """Helper function to fetch the latest version as a string from a framework's config + + Args: + framework_config (dict): A framework config dict. 
+ image_scope (str): Scope of the image, e.g. training, inference + Returns: + Version string if latest version found else None + """ + if image_scope in framework_config: + if image_scope_config := framework_config[image_scope]: + if "version_aliases" in image_scope_config: + if "latest" in image_scope_config["version_aliases"]: + return image_scope_config["version_aliases"]["latest"] + top_version = None + bottom_version = None + + if "versions" in framework_config: + versions = list(framework_config["versions"].keys()) + if len(versions) == 1: + return versions[0] + top_version = versions[0] + bottom_version = versions[-1] + if top_version == "latest" or bottom_version == "latest": + return None + elif ( + image_scope is not None + and image_scope in framework_config + and "versions" in framework_config[image_scope] + ): + versions = list(framework_config[image_scope]["versions"].keys()) + top_version = versions[0] + bottom_version = versions[-1] + elif "processing" in framework_config and "versions" in framework_config["processing"]: + versions = list(framework_config["processing"]["versions"].keys()) + top_version = versions[0] + bottom_version = versions[-1] + if top_version and bottom_version: + if top_version.endswith(".x") or bottom_version.endswith(".x"): + top_number = int(top_version[:-2]) + bottom_number = int(bottom_version[:-2]) + max_version = max(top_number, bottom_number) + return f"{max_version}.x" + if Version(top_version) >= Version(bottom_version): + return top_version + return bottom_version + + return None
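To make the version-resolution behavior added above concrete, here is a minimal illustrative sketch; it is not part of this changeset, it calls the private helper directly, and the toy config below only mirrors the shape of the image_uri_config JSON files earlier in this diff.

# Illustrative sketch (not in this diff): how _fetch_latest_version_from_config
# resolves a default version from a framework config dict.
from sagemaker.image_uris import _fetch_latest_version_from_config

toy_config = {
    "training": {
        "version_aliases": {"latest": "2.19.0"},  # a "latest" alias wins outright
        "versions": {"2.18.0": {}, "2.19.0": {}},
    },
    "versions": {"2.18.0": {}, "2.19.0": {}},
}

# The "latest" alias under the requested scope is honored first.
assert _fetch_latest_version_from_config(toy_config, image_scope="training") == "2.19.0"

# Without an alias, the helper compares the first and last version keys using
# packaging.version.Version and returns the newer of the two.
del toy_config["training"]["version_aliases"]
assert _fetch_latest_version_from_config(toy_config, image_scope="training") == "2.19.0"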
diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 89779bef44..71678021d4 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -43,6 +43,8 @@ def __init__( attribute_names: Optional[List[Union[str, PipelineVariable]]] = None, target_attribute_name: Optional[Union[str, PipelineVariable]] = None, shuffle_config: Optional["ShuffleConfig"] = None, + hub_access_config: Optional[dict] = None, + model_access_config: Optional[dict] = None, ): r"""Create a definition for input data used by a SageMaker training job. @@ -102,6 +104,13 @@ def __init__( shuffle_config (sagemaker.inputs.ShuffleConfig): If specified, this configuration enables shuffling on this channel. See the SageMaker API documentation for more info: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html + hub_access_config (dict): Specify the HubAccessConfig of a + Model Reference for which a training job is being created. + model_access_config (dict): For models that require a Model Access Config, specify True + or False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). """ self.config = { "DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}} } @@ -129,6 +138,27 @@ def __init__( self.config["TargetAttributeName"] = target_attribute_name if shuffle_config is not None: self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed} + self.add_hub_access_config(hub_access_config) + self.add_model_access_config(model_access_config) + + def add_hub_access_config(self, hub_access_config=None): + """Add Hub Access Config to the channel's configuration. + + Args: + hub_access_config (dict): The HubAccessConfig to be added to the + channel's configuration. + """ + if hub_access_config is not None: + self.config["DataSource"]["S3DataSource"]["HubAccessConfig"] = hub_access_config + + def add_model_access_config(self, model_access_config=None): + """Add Model Access Config to the channel's configuration. + + Args: + model_access_config (dict): Whether model terms of use have been accepted. + """ + if model_access_config is not None: + self.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] = model_access_config class ShuffleConfig(object): diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 48aaab0ac8..1b664fc9ae 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -30,12 +30,14 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -46,6 +48,8 @@ retrieve the default instance type. (Default: None). model_version (str): The version of the model for which to retrieve the default instance type. (Default: None). + hub_arn (str): The ARN of the SageMaker Hub from which to retrieve + model details. (default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -64,6 +68,7 @@ Optionally supply this to get an inference instance type conditioned on the training instance, to ensure compatibility of the training artifact with the inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default instance type to use for the model. @@ -82,12 +87,14 @@ model_id, model_version, scope, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, model_type=model_type, + config_name=config_name, ) @@ -95,6 +102,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -110,6 +118,8 @@ retrieve the supported instance types. (Default: None). model_version (str): The version of the model for which to retrieve the supported instance types. (Default: None). + hub_arn (str): The ARN of the SageMaker Hub from which to retrieve + model details. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised).
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -142,12 +152,13 @@ def retrieve( raise ValueError("Must specify scope for instance types.") return artifacts._retrieve_instance_types( - model_id, - model_version, - scope, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + scope=scope, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, ) diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 7040c376ab..6917421c04 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -65,6 +65,7 @@ def stop(self): @staticmethod def _load_config(inputs, estimator, expand_role=True, validate_uri=True): """Placeholder docstring""" + model_access_config, hub_access_config = _Job._get_access_configs(estimator) input_config = _Job._format_inputs_to_input_config(inputs, validate_uri) role = ( estimator.sagemaker_session.expand_role(estimator.role) @@ -83,6 +84,8 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): estimator.volume_size, estimator.volume_kms_key, estimator.keep_alive_period_in_seconds, + estimator.training_plan, + estimator.instance_placement_config, ) stop_condition = _Job._prepare_stop_condition(estimator.max_run, estimator.max_wait) vpc_config = estimator.get_vpc_config() @@ -94,19 +97,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): validate_uri, content_type="application/x-sagemaker-model", input_mode="File", + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) if model_channel: input_config = [] if input_config is None else input_config input_config.append(model_channel) - if estimator.enable_network_isolation(): - code_channel = _Job._prepare_channel( - input_config, estimator.code_uri, estimator.code_channel_name, validate_uri - ) + code_channel = _Job._prepare_channel( + input_config, + estimator.code_uri, + estimator.code_channel_name, + validate_uri, + ) - if code_channel: - input_config = [] if input_config is None else input_config - input_config.append(code_channel) + if code_channel: + input_config = [] if input_config is None else input_config + input_config.append(code_channel) return { "input_config": input_config, @@ -117,6 +124,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): "vpc_config": vpc_config, } + @staticmethod + def _get_access_configs(estimator): + """Return access configs from estimator object. + + JumpStartEstimator uses access configs which need to be added to the model channel, + so they are passed down to the job level. 
+ + Args: + estimator (EstimatorBase): estimator object with access config field if applicable + """ + model_access_config, hub_access_config = None, None + if hasattr(estimator, "model_access_config"): + model_access_config = estimator.model_access_config + if hasattr(estimator, "hub_access_config"): + hub_access_config = estimator.hub_access_config + return model_access_config, hub_access_config + @staticmethod def _format_inputs_to_input_config(inputs, validate_uri=True): """Placeholder docstring""" @@ -172,6 +196,8 @@ def _format_string_uri_input( input_mode=None, compression=None, target_attribute_name=None, + model_access_config=None, + hub_access_config=None, ): """Placeholder docstring""" s3_input_result = TrainingInput( @@ -180,6 +206,8 @@ def _format_string_uri_input( input_mode=input_mode, compression=compression, target_attribute_name=target_attribute_name, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"): return s3_input_result @@ -192,7 +220,11 @@ def _format_string_uri_input( ) if isinstance(uri_input, str): return s3_input_result - if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)): + if isinstance(uri_input, (file_input, FileSystemInput)): + return uri_input + if isinstance(uri_input, TrainingInput): + uri_input.add_hub_access_config(hub_access_config=hub_access_config) + uri_input.add_model_access_config(model_access_config=model_access_config) return uri_input if is_pipeline_variable(uri_input): return s3_input_result @@ -210,6 +242,8 @@ def _prepare_channel( validate_uri=True, content_type=None, input_mode=None, + model_access_config=None, + hub_access_config=None, ): """Placeholder docstring""" if not channel_uri: @@ -225,7 +259,12 @@ def _prepare_channel( raise ValueError("Duplicate channel {} not allowed.".format(channel_name)) channel_input = _Job._format_string_uri_input( - channel_uri, validate_uri, content_type, input_mode + channel_uri, + validate_uri, + content_type, + input_mode, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) channel = _Job._convert_input_to_channel(channel_name, channel_input) @@ -294,6 +333,8 @@ def _prepare_resource_config( volume_size, volume_kms_key, keep_alive_period_in_seconds, + training_plan, + instance_placement_config=None, ): """Placeholder docstring""" resource_config = { @@ -319,6 +360,10 @@ def _prepare_resource_config( ) resource_config["InstanceCount"] = instance_count resource_config["InstanceType"] = instance_type + if training_plan is not None: + resource_config["TrainingPlanArn"] = training_plan + if instance_placement_config is not None: + resource_config["InstancePlacementConfig"] = instance_placement_config return resource_config diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 35df030ddc..9ebc2880bc 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -10,17 +10,26 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
+# pylint: skip-file """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import import functools +import logging from typing import Any, Dict, List, Optional import boto3 from sagemaker.deprecations import deprecated -from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs +from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs, HubContentType from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import cache +from sagemaker.jumpstart.hub.utils import ( + construct_hub_model_arn_from_inputs, + construct_hub_model_reference_arn_from_inputs, + generate_hub_arn_for_init_kwargs, +) from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.session import Session +from sagemaker.jumpstart import constants class SageMakerSettings(object): @@ -253,8 +262,10 @@ def get_model_specs( region: str, model_id: str, version: str, + hub_arn: Optional[str] = None, s3_client: Optional[boto3.client] = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. @@ -270,10 +281,56 @@ def get_model_specs( if s3_client is not None: additional_kwargs.update({"s3_client": s3_client}) + if hub_arn: + additional_kwargs.update({"sagemaker_session": sagemaker_session}) + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs} ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + + # Users only input model id, not contentType, so first try to describe with ModelReference, then with Model + if hub_arn: + try: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_arn, region=region, session=sagemaker_session + ) + + hub_model_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + model_specs = JumpStartModelsAccessor._cache.get_hub_model_reference( + hub_model_reference_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE) + return model_specs + + except Exception as ex: + logging.info( + "Received exception while calling APIs for ContentType ModelReference, \ + retrying with ContentType Model: " + + str(ex) + ) + hub_model_arn = construct_hub_model_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + + # Failed to describe ModelReference, try with Model + try: + model_specs = JumpStartModelsAccessor._cache.get_hub_model( + hub_model_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL) + + return model_specs + except Exception as ex: + # Failed with both, raise a custom error message + raise RuntimeError( + f"Cannot get details for {model_id} in Hub {hub_arn}.
\ + {model_id} does not exist as a Model or ModelReference: \n" + + str(ex) + ) + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, version_str=version, model_type=model_type ) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index c28c27ed4e..48775542e6 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -19,6 +19,7 @@ SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( @@ -32,6 +33,7 @@ def _retrieve_default_environment_variables( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -39,6 +41,8 @@ def _retrieve_default_environment_variables( sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Dict[str, str]: """Retrieves the inference environment variables for the model matching the given arguments. @@ -47,6 +51,8 @@ def _retrieve_default_environment_variables( retrieve the default environment variables. model_version (str): Version of the JumpStart model for which to retrieve the default environment variables. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default environment variables. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -68,6 +74,9 @@ def _retrieve_default_environment_variables( environment variables specific for the instance type. script (JumpStartScriptScope): The JumpStart script for which to retrieve environment variables. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: the inference environment variables to use for the model. 
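The `get_model_specs` change above resolves hub content by trying the `ModelReference` content type first and falling back to `Model`, raising only when both lookups fail. A distilled sketch of that fallback shape, with hypothetical fetcher callables standing in for the cache's `get_hub_model_reference`/`get_hub_model` calls:

```python
import logging


def describe_with_fallback(model_id, fetch_model_reference, fetch_model):
    """Try the hub content as a ModelReference, then as a Model."""
    try:
        return fetch_model_reference(model_id)
    except Exception as ex:  # broad catch mirrors the diff
        logging.info("ModelReference lookup failed, retrying as Model: %s", ex)
    try:
        return fetch_model(model_id)
    except Exception as ex:
        raise RuntimeError(
            f"Cannot get details for {model_id}: "
            "it exists as neither a Model nor a ModelReference"
        ) from ex
```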
""" @@ -79,11 +88,14 @@ def _retrieve_default_environment_variables( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) default_environment_variables: Dict[str, str] = {} @@ -116,11 +128,14 @@ def _retrieve_default_environment_variables( lambda instance_type: _retrieve_gated_model_uri_env_var_value( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, instance_type=instance_type, + config_name=config_name, + model_type=model_type, ) ) @@ -162,11 +177,14 @@ def _retrieve_default_environment_variables( def _retrieve_gated_model_uri_env_var_value( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Retrieves the gated model env var URI matching the given arguments. @@ -175,6 +193,8 @@ def _retrieve_gated_model_uri_env_var_value( retrieve the gated model env var URI. model_version (str): Version of the JumpStart model for which to retrieve the gated model env var URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve the gated model env var URI. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -190,7 +210,9 @@ def _retrieve_gated_model_uri_env_var_value( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get environment variables specific for the instance type. - + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: Optional[str]: the s3 URI to use for the environment variable, or None if the model does not have gated training artifacts. 
@@ -206,11 +228,14 @@ def _retrieve_gated_model_uri_env_var_value( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) s3_key: Optional[str] = ( @@ -221,4 +246,7 @@ def _retrieve_gated_model_uri_env_var_value( if s3_key is None: return None + if hub_arn: + return s3_key + return f"s3://{get_jumpstart_gated_content_bucket(region)}/{s3_key}" diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index d19530ecfb..4bfe1732be 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -17,6 +17,7 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, VariableScope, ) @@ -30,12 +31,15 @@ def _retrieve_default_hyperparameters( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ): """Retrieves the training hyperparameters for the model matching the given arguments. @@ -44,6 +48,8 @@ def _retrieve_default_hyperparameters( retrieve the default hyperparameters. model_version (str): Version of the JumpStart model for which to retrieve the default hyperparameters. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (str): Region for which to retrieve default hyperparameters. (Default: None). include_container_hyperparameters (bool): True if container hyperparameters @@ -66,6 +72,9 @@ def _retrieve_default_hyperparameters( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: the hyperparameters to use for the model. 
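For orientation, a hypothetical call site for `_retrieve_default_hyperparameters` with the newly threaded keywords. The model ID and hub ARN are placeholders, the helper is internal to the SDK, and running this would require AWS credentials and a real hub:

```python
from sagemaker.jumpstart.artifacts.hyperparameters import (
    _retrieve_default_hyperparameters,
)

hyperparameters = _retrieve_default_hyperparameters(
    model_id="huggingface-llm-example",  # placeholder model ID
    model_version="*",
    hub_arn="arn:aws:sagemaker:us-west-2:123456789012:hub/my-hub",  # placeholder
    region="us-west-2",
    config_name=None,  # fall back to the model's default config
)
```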
""" @@ -77,11 +86,14 @@ def _retrieve_default_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) default_hyperparameters: Dict[str, str] = {} diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 9d19d5e069..8bcb205baa 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -14,13 +14,12 @@ from __future__ import absolute_import from typing import Optional -from sagemaker import image_uris from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, - ModelFramework, ) from sagemaker.jumpstart.utils import ( get_region_fallback, @@ -33,19 +32,14 @@ def _retrieve_image_uri( model_id: str, model_version: str, image_scope: str, - framework: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, - version: Optional[str] = None, - py_version: Optional[str] = None, instance_type: Optional[str] = None, - accelerator_type: Optional[str] = None, - container_version: Optional[str] = None, - distribution: Optional[str] = None, - base_framework_version: Optional[str] = None, - training_compiler_config: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ): """Retrieves the container image URI for JumpStart models. @@ -57,33 +51,16 @@ def _retrieve_image_uri( model_id (str): JumpStart model ID for which to retrieve image URI. model_version (str): Version of the JumpStart model for which to retrieve the image URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). image_scope (str): The image type, i.e. what it is used for. Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. - framework (str): The name of the framework or algorithm. region (str): The AWS region. (Default: None). - version (str): The framework or algorithm version. This is required if there is - more than one supported version for the given framework or algorithm. - (Default: None). - py_version (str): The Python version. This is required if there is - more than one supported Python version for the given framework version. instance_type (str): The SageMaker instance type. For supported types, see https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if there are different images for different processor types. (Default: None). - accelerator_type (str): Elastic Inference accelerator type. For more, see - https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html. - (Default: None). - container_version (str): the version of docker image. - Ideally the value of parameter should be created inside the framework. - For custom use, see the list of supported container versions: - https://github.com/aws/deep-learning-containers/blob/master/available_images.md. - (Default: None). 
- distribution (dict): A dictionary with information on how to run distributed training. - (Default: None). - training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`): - A configuration class for the SageMaker Training Compiler. - (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -95,6 +72,9 @@ def _retrieve_image_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: the ECR URI for the corresponding SageMaker Docker image. @@ -111,11 +91,14 @@ def _retrieve_image_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=image_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) if image_scope == JumpStartScriptScope.INFERENCE: @@ -126,14 +109,16 @@ def _retrieve_image_uri( ) if image_uri is not None: return image_uri - ecr_specs = model_specs.hosting_ecr_specs - if ecr_specs is None: - raise ValueError( - f"No inference ECR configuration found for JumpStart model ID '{model_id}' " - f"with {instance_type} instance type in {region}. " - "Please try another instance type or region." - ) - elif image_scope == JumpStartScriptScope.TRAINING: + if hub_arn: + ecr_uri = model_specs.hosting_ecr_uri + return ecr_uri + + raise ValueError( + f"No inference ECR configuration found for JumpStart model ID '{model_id}' " + f"with {instance_type} instance type in {region}. " + "Please try another instance type or region." + ) + if image_scope == JumpStartScriptScope.TRAINING: training_instance_type_variants = model_specs.training_instance_type_variants if training_instance_type_variants: image_uri = training_instance_type_variants.get_image_uri( @@ -141,63 +126,14 @@ def _retrieve_image_uri( ) if image_uri is not None: return image_uri - ecr_specs = model_specs.training_ecr_specs - if ecr_specs is None: - raise ValueError( - f"No training ECR configuration found for JumpStart model ID '{model_id}' " - f"with {instance_type} instance type in {region}. " - "Please try another instance type or region." - ) - if framework is not None and framework != ecr_specs.framework: - raise ValueError( - f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' " - f"and version '{model_version}'." - ) - - if version is not None and version != ecr_specs.framework_version: - raise ValueError( - f"Incorrect container framework version '{version}' for JumpStart model ID " - f"'{model_id}' and version '{model_version}'." - ) + if hub_arn: + ecr_uri = model_specs.training_ecr_uri + return ecr_uri - if py_version is not None and py_version != ecr_specs.py_version: raise ValueError( - f"Incorrect python version '{py_version}' for JumpStart model ID '{model_id}' " - f"and version '{model_version}'." 
+ f"No training ECR configuration found for JumpStart model ID '{model_id}' " + f"with {instance_type} instance type in {region}. " + "Please try another instance type or region." ) - base_framework_version_override: Optional[str] = None - version_override: Optional[str] = None - if ecr_specs.framework == ModelFramework.HUGGINGFACE: - base_framework_version_override = ecr_specs.framework_version - version_override = ecr_specs.huggingface_transformers_version - - if image_scope == JumpStartScriptScope.TRAINING: - return image_uris.get_training_image_uri( - region=region, - framework=ecr_specs.framework, - framework_version=version_override or ecr_specs.framework_version, - py_version=ecr_specs.py_version, - image_uri=None, - distribution=None, - compiler_config=None, - tensorflow_version=None, - pytorch_version=base_framework_version_override or base_framework_version, - instance_type=instance_type, - ) - if base_framework_version_override is not None: - base_framework_version_override = f"pytorch{base_framework_version_override}" - - return image_uris.retrieve( - framework=ecr_specs.framework, - region=region, - version=version_override or ecr_specs.framework_version, - py_version=ecr_specs.py_version, - instance_type=instance_type, - accelerator_type=accelerator_type, - image_scope=image_scope, - container_version=container_version, - distribution=distribution, - base_framework_version=base_framework_version_override or base_framework_version, - training_compiler_config=training_compiler_config, - ) + raise ValueError(f"Invalid scope: {image_scope}") diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 1b3c6f4b29..f3b44524e7 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -17,6 +17,7 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( @@ -30,9 +31,12 @@ def _model_supports_incremental_training( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> bool: """Returns True if the model supports incremental training. @@ -43,6 +47,8 @@ def _model_supports_incremental_training( support status for incremental training. region (Optional[str]): Region for which to retrieve the support status for incremental training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -54,6 +60,9 @@ def _model_supports_incremental_training( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. 
(Default: JumpStartModelType.OPEN_WEIGHTS). Returns: bool: the support status for incremental training. """ @@ -65,11 +74,14 @@ def _model_supports_incremental_training( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) return model_specs.supports_incremental_training() diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index e7c9c5911d..25119266cf 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -34,12 +34,14 @@ def _retrieve_default_instance_type( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model. @@ -50,6 +52,8 @@ def _retrieve_default_instance_type( default instance type. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default instance type. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -68,6 +72,7 @@ def _retrieve_default_instance_type( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default instance type to use for the model or None. @@ -83,12 +88,14 @@ def _retrieve_default_instance_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -123,11 +130,13 @@ def _retrieve_instance_types( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, training_instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported instance types for the model. @@ -138,6 +147,8 @@ def _retrieve_instance_types( supported instance types. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve supported instance types. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -156,6 +167,7 @@ def _retrieve_instance_types( Optionally supply this to get a inference instance type conditioned on the training instance, to ensure compatability of training artifact to inference instance. (Default: None). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported instance types to use for the model or None. @@ -171,11 +183,13 @@ def _retrieve_instance_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -196,7 +210,7 @@ def _retrieve_instance_types( elif scope == JumpStartScriptScope.TRAINING: if training_instance_type is not None: - raise ValueError("Cannot use `training_instance_type` argument " "with training scope.") + raise ValueError("Cannot use `training_instance_type` argument with training scope.") instance_types = model_specs.supported_training_instance_types else: raise NotImplementedError( diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 9cd152b0bb..6f2f7f38b5 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -32,11 +32,13 @@ def _retrieve_model_init_kwargs( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Model`. @@ -45,6 +47,8 @@ def _retrieve_model_init_kwargs( retrieve the kwargs. model_version (str): Version of the JumpStart model for which to retrieve the kwargs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -58,6 +62,7 @@ def _retrieve_model_init_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. 
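The kwargs helpers below return `deepcopy(model_specs.model_kwargs)` rather than the cached dict itself, so callers can mutate the result freely. A toy illustration of why the copy matters; `ToySpecs` is a stand-in for the cached specs object:

```python
from copy import deepcopy


class ToySpecs:
    """Stand-in for the cached model specs object."""

    model_kwargs = {"env": {"KEY": "VALUE"}}


def retrieve_model_init_kwargs(specs):
    return deepcopy(specs.model_kwargs)


kwargs = retrieve_model_init_kwargs(ToySpecs)
kwargs["env"]["KEY"] = "OVERRIDDEN"  # caller mutates its copy...
assert ToySpecs.model_kwargs["env"]["KEY"] == "VALUE"  # ...cache is untouched
```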
""" @@ -69,12 +74,14 @@ def _retrieve_model_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -89,11 +96,13 @@ def _retrieve_model_deploy_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -104,6 +113,8 @@ def _retrieve_model_deploy_kwargs( kwargs. instance_type (str): Instance type of the hosting endpoint, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -117,6 +128,7 @@ def _retrieve_model_deploy_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: dict: the kwargs to use for the use case. @@ -129,12 +141,14 @@ def _retrieve_model_deploy_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: @@ -147,10 +161,13 @@ def _retrieve_estimator_init_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Estimator`. @@ -161,6 +178,8 @@ def _retrieve_estimator_init_kwargs( kwargs. instance_type (str): Instance type of the training job, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -174,6 +193,9 @@ def _retrieve_estimator_init_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). 
+ model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: the kwargs to use for the use case. """ @@ -185,11 +207,14 @@ def _retrieve_estimator_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) kwargs = deepcopy(model_specs.estimator_kwargs) @@ -206,10 +231,13 @@ def _retrieve_estimator_init_kwargs( def _retrieve_estimator_fit_kwargs( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> dict: """Retrieves kwargs for `Estimator.fit`. @@ -218,6 +246,8 @@ def _retrieve_estimator_fit_kwargs( retrieve the kwargs. model_version (str): Version of the JumpStart model for which to retrieve the kwargs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -231,6 +261,9 @@ def _retrieve_estimator_fit_kwargs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: dict: the kwargs to use for the use case. 
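The deploy and estimator kwargs helpers only attach a default volume size when the instance type supports EBS volumes, since NVMe-backed instance families reject `volume_size`. A sketch of that gating with an illustrative, non-exhaustive family list; the SDK keeps its own table:

```python
# Illustrative subset of NVMe-backed families; the SDK keeps its own list.
NVME_FAMILIES = {"ml.p4d", "ml.g5", "ml.trn1"}


def volume_size_supported(instance_type):
    family = ".".join(instance_type.split(".")[:2])
    return family not in NVME_FAMILIES


def apply_default_volume_size(kwargs, instance_type, default_volume_size):
    if volume_size_supported(instance_type) and default_volume_size is not None:
        kwargs["volume_size"] = default_volume_size
    return kwargs


print(apply_default_volume_size({}, "ml.m5.xlarge", 512))    # {'volume_size': 512}
print(apply_default_volume_size({}, "ml.g5.12xlarge", 512))  # {}
```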
@@ -243,11 +276,14 @@ def _retrieve_estimator_fit_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) return model_specs.fit_kwargs diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 57f66155c7..d4a0386c08 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -18,6 +18,7 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( @@ -31,10 +32,13 @@ def _retrieve_default_training_metric_definitions( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model. @@ -45,6 +49,8 @@ def _retrieve_default_training_metric_definitions( default training metric definitions. region (Optional[str]): Region for which to retrieve default training metric definitions. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -58,6 +64,9 @@ def _retrieve_default_training_metric_definitions( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: list: the default training metric definitions to use for the model or None. 
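The metric-definition change further down merges instance-specific definitions into the defaults: any default with the same `Name` is filtered out before the override is appended, so names stay unique. A condensed sketch of that merge:

```python
def merge_metric_definitions(defaults, instance_specific):
    """Instance-specific definitions replace same-named defaults."""
    merged = list(defaults)
    for override in instance_specific or []:
        merged = [m for m in merged if m["Name"] != override["Name"]]
        merged.append(override)
    return merged


defaults = [{"Name": "loss", "Regex": "loss=([0-9.]+)"}]
overrides = [{"Name": "loss", "Regex": "train_loss: ([0-9.]+)"}]
print(merge_metric_definitions(defaults, overrides))
# [{'Name': 'loss', 'Regex': 'train_loss: ([0-9.]+)'}]
```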
""" @@ -69,11 +78,14 @@ def _retrieve_default_training_metric_definitions( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) default_metric_definitions = ( @@ -89,16 +101,17 @@ def _retrieve_default_training_metric_definitions( else [] ) - instance_specific_metric_name: str - for instance_specific_metric_definition in instance_specific_metric_definitions: - instance_specific_metric_name = instance_specific_metric_definition["Name"] - default_metric_definitions = list( - filter( - lambda metric_definition: metric_definition["Name"] - != instance_specific_metric_name, - default_metric_definitions, + if instance_specific_metric_definitions: + instance_specific_metric_name: str + for instance_specific_metric_definition in instance_specific_metric_definitions: + instance_specific_metric_name = instance_specific_metric_definition["Name"] + default_metric_definitions = list( + filter( + lambda metric_definition: metric_definition["Name"] + != instance_specific_metric_name, + default_metric_definitions, + ) ) - ) - default_metric_definitions.append(instance_specific_metric_definition) + default_metric_definitions.append(instance_specific_metric_definition) return default_metric_definitions diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index aa22351771..c3b967d83e 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -32,11 +32,13 @@ def _retrieve_model_package_arn( model_version: str, instance_type: Optional[str], region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -48,6 +50,8 @@ def _retrieve_model_package_arn( instance_type (Optional[str]): An instance type to optionally supply in order to get an arn specific for the instance type. region (Optional[str]): Region for which to retrieve the model package arn. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (Optional[str]): Scope for which to retrieve the model package arn. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -60,6 +64,7 @@ def _retrieve_model_package_arn( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the model package arn to use for the model or None. 
@@ -72,12 +77,14 @@ def _retrieve_model_package_arn( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: @@ -93,7 +100,10 @@ def _retrieve_model_package_arn( if instance_specific_arn is not None: return instance_specific_arn - if model_specs.hosting_model_package_arns is None: + if ( + model_specs.hosting_model_package_arns is None + or model_specs.hosting_model_package_arns == {} + ): return None regional_arn = model_specs.hosting_model_package_arns.get(region) @@ -114,10 +124,13 @@ def _retrieve_model_package_model_artifact_s3_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[str]: """Retrieves s3 artifact uri associated with model package. @@ -128,6 +141,8 @@ def _retrieve_model_package_model_artifact_s3_uri( model package artifact. region (Optional[str]): Region for which to retrieve the model package artifact. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (Optional[str]): Scope for which to retrieve the model package artifact. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -141,6 +156,9 @@ def _retrieve_model_package_model_artifact_s3_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: the model package artifact uri to use for the model or None. 
@@ -157,11 +175,14 @@ def _retrieve_model_package_model_artifact_s3_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) if model_specs.training_model_package_artifact_uris is None: diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 6bb2e576fc..c1ad9710f1 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -20,6 +20,7 @@ ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( @@ -28,6 +29,7 @@ get_region_fallback, verify_model_region_and_return_specs, ) +from sagemaker.s3_utils import is_s3_url from sagemaker.session import Session from sagemaker.jumpstart.types import JumpStartModelSpecs @@ -73,7 +75,7 @@ def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_ty def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str: """Returns instance specific training artifact key or default one as fallback.""" instance_specific_training_artifact_key: Optional[str] = ( - model_specs.training_instance_type_variants.get_instance_specific_artifact_key( + model_specs.training_instance_type_variants.get_instance_specific_training_artifact_key( instance_type=instance_type ) if instance_type @@ -89,12 +91,15 @@ def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_t def _retrieve_model_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -105,6 +110,8 @@ def _retrieve_model_uri( the model artifact S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model artifact S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -120,6 +127,10 @@ def _retrieve_model_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). + Returns: str: the model artifact S3 URI for the corresponding model. 
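With a `hub_arn`, `_retrieve_model_uri` now short-circuits to the spec's own `hosting_artifact_uri`, and the bucket prefix is applied only when the artifact key is not already an `s3://` URL. A sketch of that resolution; returning an already-qualified URL unchanged is this sketch's assumption about the intended behavior, and `startswith` stands in for the SDK's `is_s3_url`:

```python
def resolve_model_uri(artifact_key_or_uri, bucket):
    if artifact_key_or_uri.startswith("s3://"):  # stands in for is_s3_url()
        return artifact_key_or_uri
    return f"s3://{bucket}/{artifact_key_or_uri}"


print(resolve_model_uri("pytorch/inference/model.tar.gz", "my-bucket"))
print(resolve_model_uri("s3://hub-bucket/model.tar.gz", "my-bucket"))
```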
@@ -136,11 +147,14 @@ def _retrieve_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=model_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) model_artifact_key: str @@ -149,6 +163,9 @@ def _retrieve_model_uri( is_prepacked = not model_specs.use_inference_script_uri() + if hub_arn: + model_artifact_uri = model_specs.hosting_artifact_uri + return model_artifact_uri model_artifact_key = ( _retrieve_hosting_prepacked_artifact_key(model_specs, instance_type) if is_prepacked @@ -169,8 +186,8 @@ def _retrieve_model_uri( os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE) or default_jumpstart_bucket ) - - model_s3_uri = f"s3://{bucket}/{model_artifact_key}" + if not is_s3_url(model_artifact_key): + model_s3_uri = f"s3://{bucket}/{model_artifact_key}" return model_s3_uri @@ -179,9 +196,12 @@ def _model_supports_training_model_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> bool: """Returns True if the model supports training with model uri field. @@ -192,6 +212,8 @@ def _model_supports_training_model_uri( support status for model uri with training. region (Optional[str]): Region for which to retrieve the support status for model uri with training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -203,6 +225,9 @@ def _model_supports_training_model_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: bool: the support status for model uri with training. 
""" @@ -214,11 +239,14 @@ def _model_supports_training_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) return model_specs.use_training_model_artifact() diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3359e32732..c217495ede 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -33,10 +33,12 @@ def _retrieve_example_payloads( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Optional[Dict[str, JumpStartSerializablePayload]]: """Returns example payloads. @@ -47,6 +49,8 @@ def _retrieve_example_payloads( example payloads. region (Optional[str]): Region for which to retrieve the example payloads. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -58,6 +62,7 @@ def _retrieve_example_payloads( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: Optional[Dict[str, JumpStartSerializablePayload]]: dictionary mapping payload aliases to the serializable payload object. @@ -70,12 +75,14 @@ def _retrieve_example_payloads( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_payloads = model_specs.default_payloads diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 4f6dfe1fe3..352a4384f8 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -73,11 +73,13 @@ def _retrieve_deserializer_from_accept_type( def _retrieve_default_deserializer( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -86,6 +88,8 @@ def _retrieve_default_deserializer( retrieve the default deserializer. 
model_version (str): Version of the JumpStart model for which to retrieve the default deserializer. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default deserializer. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -98,6 +102,7 @@ def _retrieve_default_deserializer( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseDeserializer: the default deserializer to use for the model. @@ -106,11 +111,13 @@ def _retrieve_default_deserializer( default_accept_type = _retrieve_default_accept_type( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -119,11 +126,13 @@ def _retrieve_default_deserializer( def _retrieve_default_serializer( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -132,6 +141,8 @@ def _retrieve_default_serializer( retrieve the default serializer. model_version (str): Version of the JumpStart model for which to retrieve the default serializer. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default serializer. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -144,6 +155,7 @@ def _retrieve_default_serializer( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: BaseSerializer: the default serializer to use for the model. 
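The predictor helpers map a model's default content and accept types onto serializers and deserializers via `MIMEType.from_suffixed_type`. A toy dispatch table showing the shape of that lookup; the mapping below is illustrative, not the SDK's actual table, which resolves to serializer classes rather than bare callables:

```python
import json

SERIALIZERS = {
    "application/json": json.dumps,
    "text/csv": lambda rows: "\n".join(",".join(map(str, r)) for r in rows),
}


def serializer_for(content_type):
    base = content_type.split(";")[0]  # drop suffixes such as ";verbose"
    try:
        return SERIALIZERS[base]
    except KeyError:
        raise ValueError(f"Unsupported content type: {base}") from None


print(serializer_for("application/json")({"inputs": "hello"}))
print(serializer_for("text/csv")([[1, 2], [3, 4]]))
```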
""" @@ -151,11 +163,13 @@ def _retrieve_default_serializer( default_content_type = _retrieve_default_content_type( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -164,11 +178,13 @@ def _retrieve_default_serializer( def _retrieve_deserializer_options( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -177,6 +193,8 @@ def _retrieve_deserializer_options( retrieve the supported deserializers. model_version (str): Version of the JumpStart model for which to retrieve the supported deserializers. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve deserializer options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -189,6 +207,7 @@ def _retrieve_deserializer_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[BaseDeserializer]: the supported deserializers to use for the model. """ @@ -196,11 +215,13 @@ def _retrieve_deserializer_options( supported_accept_types = _retrieve_supported_accept_types( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) seen_classes: Set[Type] = set() @@ -223,10 +244,12 @@ def _retrieve_deserializer_options( def _retrieve_serializer_options( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model. @@ -235,6 +258,8 @@ def _retrieve_serializer_options( retrieve the supported serializers. model_version (str): Version of the JumpStart model for which to retrieve the supported serializers. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve serializer options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -247,6 +272,7 @@ def _retrieve_serializer_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. 
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[BaseSerializer]: the supported serializers to use for the model. """ @@ -254,10 +280,12 @@ def _retrieve_serializer_options( supported_content_types = _retrieve_supported_content_types( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) seen_classes: Set[Type] = set() @@ -280,11 +308,13 @@ def _retrieve_serializer_options( def _retrieve_default_content_type( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieves the default content type for the model. @@ -293,6 +323,8 @@ def _retrieve_default_content_type( retrieve the default content type. model_version (str): Version of the JumpStart model for which to retrieve the default content type. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default content type. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -305,6 +337,7 @@ def _retrieve_default_content_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default content type to use for the model. """ @@ -316,12 +349,14 @@ def _retrieve_default_content_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -331,11 +366,13 @@ def _retrieve_default_content_type( def _retrieve_default_accept_type( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> str: """Retrieves the default accept type for the model. @@ -344,6 +381,8 @@ def _retrieve_default_accept_type( retrieve the default accept type. model_version (str): Version of the JumpStart model for which to retrieve the default accept type. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default accept type. 
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -356,6 +395,7 @@ def _retrieve_default_accept_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: the default accept type to use for the model. """ @@ -367,12 +407,14 @@ def _retrieve_default_accept_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -383,11 +425,13 @@ def _retrieve_default_accept_type( def _retrieve_supported_accept_types( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -396,6 +440,8 @@ def _retrieve_supported_accept_types( retrieve the supported accept types. model_version (str): Version of the JumpStart model for which to retrieve the supported accept types. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve accept type options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -408,6 +454,7 @@ def _retrieve_supported_accept_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported accept types to use for the model. """ @@ -419,12 +466,14 @@ def _retrieve_supported_accept_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -435,11 +484,13 @@ def _retrieve_supported_accept_types( def _retrieve_supported_content_types( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> List[str]: """Retrieves the supported content types for the model. 
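Each of these retrieval helpers follows the same pattern: resolve the model specs, then read a field off predictor_specs. A standalone sketch of that flow, using the internal helper as it is called throughout this diff (model ID and printed values hypothetical):

from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.utils import verify_model_region_and_return_specs

specs = verify_model_region_and_return_specs(
    model_id="huggingface-text2text-flan-t5-base",  # hypothetical
    version="*",
    scope=JumpStartScriptScope.INFERENCE,
    region="us-west-2",
)
print(specs.predictor_specs.default_content_type)    # e.g. "application/json"
print(specs.predictor_specs.supported_accept_types)  # e.g. ["application/json"]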
@@ -448,6 +499,8 @@ def _retrieve_supported_content_types( retrieve the supported content types. model_version (str): Version of the JumpStart model for which to retrieve the supported content types. + hub_arn (Optional[str]): The ARN of the SageMaker Hub from which to retrieve + model details. (Default: None). region (Optional[str]): Region for which to retrieve content type options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -460,6 +513,7 @@ def _retrieve_supported_content_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: list: the supported content types to use for the model. """ @@ -471,12 +525,14 @@ def _retrieve_supported_content_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index cffd46d043..8c47750061 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -31,10 +31,13 @@ def _retrieve_resource_name_base( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + scope: JumpStartScriptScope = JumpStartScriptScope.INFERENCE, + config_name: Optional[str] = None, ) -> str: """Returns default resource name. Args: model_id (str): JumpStart model ID of the JumpStart model for which to retrieve the @@ -45,6 +48,8 @@ def _retrieve_resource_name_base( default resource name. region (Optional[str]): Region for which to retrieve the default resource name. + hub_arn (Optional[str]): The ARN of the SageMaker Hub from which to retrieve + model details. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -56,6 +61,7 @@ def _retrieve_resource_name_base( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config. (Default: None). Returns: str: the default resource name.
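The resource_name_base returned here seeds default endpoint and training-job names. A sketch of how a caller can derive a unique name from it, assuming the long-standing sagemaker.utils.name_from_base helper (base value hypothetical):

from sagemaker.utils import name_from_base

resource_name_base = "flan-t5-base"  # hypothetical value from the model specs
endpoint_name = name_from_base(resource_name_base)  # appends a timestamp suffix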
""" @@ -67,12 +73,14 @@ def _retrieve_resource_name_base( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, - scope=JumpStartScriptScope.INFERENCE, + hub_arn=hub_arn, + scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) return model_specs.resource_name_base diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 369acac85f..74523be1de 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -48,12 +48,14 @@ def _retrieve_default_resources( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model. @@ -64,6 +66,8 @@ def _retrieve_default_resources( default resource requirements. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default resource requirements. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -79,6 +83,7 @@ def _retrieve_default_resources( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get host requirements specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default resource requirements to use for the model or None. 
@@ -96,12 +101,14 @@ def _retrieve_default_resources( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, + config_name=config_name, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index f69732d2e0..e9b58debc3 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -19,6 +19,7 @@ ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE, ) from sagemaker.jumpstart.enums import ( + JumpStartModelType, JumpStartScriptScope, ) from sagemaker.jumpstart.utils import ( @@ -32,11 +33,14 @@ def _retrieve_script_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -47,6 +51,8 @@ def _retrieve_script_uri( retrieve the script S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. @@ -62,6 +68,9 @@ def _retrieve_script_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: the model script URI for the corresponding model. @@ -78,11 +87,14 @@ def _retrieve_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) if script_scope == JumpStartScriptScope.INFERENCE: @@ -104,10 +116,13 @@ def _retrieve_script_uri( def _model_supports_inference_script_uri( model_id: str, model_version: str, - region: Optional[str], + hub_arn: Optional[str] = None, + region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: Optional[str] = None, ) -> bool: """Returns True if the model supports inference with script uri field. @@ -116,6 +131,8 @@ def _model_supports_inference_script_uri( retrieve the support status for script uri with inference. 
model_version (str): Version of the JumpStart model for which to retrieve the support status for script uri with inference. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve the support status for script uri with inference. tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -129,6 +146,8 @@ def _model_supports_inference_script_uri( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: bool: the support status for script uri with inference. """ @@ -140,11 +159,14 @@ def _model_supports_inference_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) return model_specs.use_inference_script_uri() diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e9a34a21a8..5a4be3f53f 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -15,13 +15,14 @@ import datetime from difflib import get_close_matches import os -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import json import boto3 import botocore from packaging.version import Version from packaging.specifiers import SpecifierSet, InvalidSpecifier from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, @@ -42,16 +43,23 @@ JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelHeader, JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, + HubContentType, +) +from sagemaker.jumpstart.hub import utils as hub_utils +from sagemaker.jumpstart.hub.interfaces import DescribeHubContentResponse +from sagemaker.jumpstart.hub.parsers import ( + make_model_specs_from_describe_hub_content_response, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache +from sagemaker.session import Session class JumpStartModelsCache: @@ -77,6 +85,7 @@ def __init__( s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, + sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> None: """Initialize a ``JumpStartModelsCache`` instance. @@ -98,13 +107,15 @@ def __init__( s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. Default: None (no config). s3_client (Optional[boto3.client]): s3 client to use. Default: None. + sagemaker_session: sagemaker session object to use. 
+ Default: session object from default region us-west-2. """ self._region = region or utils.get_region_fallback( s3_bucket_name=s3_bucket_name, s3_client=s3_client ) - self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( + self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue]( max_cache_items=max_s3_cache_items, expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, @@ -139,6 +150,8 @@ def __init__( if s3_client_config else boto3.client("s3", region_name=self._region) ) + # Fallback in case a caller overrides sagemaker_session to None + self._sagemaker_session = sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION def set_region(self, region: str) -> None: """Set region for cache. Clears cache after new region is set.""" @@ -230,8 +243,8 @@ def _model_id_retrieval_function( model_id, version = key.model_id, key.version sm_version = utils.get_sagemaker_version() - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -250,7 +263,7 @@ def _model_id_retrieval_function( return JumpStartVersionedModelId(model_id, sm_compatible_model_version) versions_incompatible_with_sagemaker = [ - Version(header.version) + header.version for header in manifest.values() # type: ignore if header.model_id == model_id ] @@ -281,7 +294,8 @@ def _model_id_retrieval_function( raise KeyError(error_msg) error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. " - error_msg += f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. " + error_msg += "Specify a different model ID or try a different AWS Region. " + error_msg += f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. " other_model_id_version = None if model_type == JumpStartModelType.OPEN_WEIGHTS: @@ -358,10 +372,18 @@ def _get_json_file( object and None when reading from the local file system. """ if self._is_local_metadata_mode(): - file_content, etag = self._get_json_file_from_local_override(key, filetype), None - else: - file_content, etag = self._get_json_file_and_etag_from_s3(key) - return file_content, etag + if filetype in { + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartS3FileType.OPEN_WEIGHT_SPECS, + }: + return self._get_json_file_from_local_override(key, filetype), None + else: + JUMPSTART_LOGGER.warning( + "Local metadata mode is enabled, but the file type %s is not supported " + "for local override. Falling back to s3.", + filetype, + ) + return self._get_json_file_and_etag_from_s3(key) def _get_json_md5_hash(self, key: str): """Retrieves md5 object hash for s3 objects, using `s3.head_object`. @@ -392,53 +414,101 @@ def _get_json_file_from_local_override( def _retrieval_function( self, - key: JumpStartCachedS3ContentKey, - value: Optional[JumpStartCachedS3ContentValue], - ) -> JumpStartCachedS3ContentValue: - """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``. + key: JumpStartCachedContentKey, + value: Optional[JumpStartCachedContentValue], + ) -> JumpStartCachedContentValue: + """Return s3 content given a file type and s3_key in ``JumpStartCachedContentKey``. If a manifest file is being fetched, we only download the object if the md5 hash in ``head_object`` does not match the current md5 hash for the stored value. 
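A standalone sketch of that caching rule, decoupled from the cache classes in this diff: reuse the cached manifest whenever the remote hash still matches, and download only when it changes.

from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass
class CachedValue:
    body: Any
    md5_hash: Optional[str]

def retrieve(
    key: str,
    cached: Optional[CachedValue],
    head_hash: Callable[[str], str],   # cheap HEAD-style hash lookup
    get_object: Callable[[str], Any],  # full object download
) -> CachedValue:
    remote_hash = head_hash(key)
    if cached is not None and cached.md5_hash == remote_hash:
        return cached  # manifest unchanged; skip the download
    return CachedValue(body=get_object(key), md5_hash=remote_hash)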
This prevents unnecessarily downloading the full manifest when it hasn't changed. Args: - key (JumpStartCachedS3ContentKey): key for which to fetch s3 content. + key (JumpStartCachedContentKey): key for which to fetch s3 content. value (Optional[JumpStartVersionedModelId]): Current value of old cached s3 content. This is used for the manifest file, so that it is only downloaded when its content changes. """ - file_type, s3_key = key.file_type, key.s3_key - if file_type in { + data_type, id_info = key.data_type, key.id_info + + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartS3FileType.PROPRIETARY_MANIFEST, }: if value is not None and not self._is_local_metadata_mode(): - etag = self._get_json_md5_hash(s3_key) + etag = self._get_json_md5_hash(id_info) if etag == value.md5_hash: return value - formatted_body, etag = self._get_json_file(s3_key, file_type) - return JumpStartCachedS3ContentValue( + formatted_body, etag = self._get_json_file(id_info, data_type) + return JumpStartCachedContentValue( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - if file_type in { + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_SPECS, JumpStartS3FileType.PROPRIETARY_SPECS, }: - formatted_body, _ = self._get_json_file(s3_key, file_type) + formatted_body, _ = self._get_json_file(id_info, data_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedS3ContentValue(formatted_content=model_specs) - raise ValueError(self._file_type_error_msg(file_type)) + return JumpStartCachedContentValue(formatted_content=model_specs) + + if data_type == HubContentType.NOTEBOOK: + hub_name, _, notebook_name, notebook_version = hub_utils.get_info_from_hub_resource_arn( + id_info + ) + response: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=notebook_name, + hub_content_version=notebook_version, + hub_content_type=data_type, + ) + hub_notebook_description = DescribeHubContentResponse(response) + return JumpStartCachedContentValue(formatted_content=hub_notebook_description) + + if data_type in { + HubContentType.MODEL, + HubContentType.MODEL_REFERENCE, + }: + + hub_resource_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(id_info) + hub_arn = hub_utils.construct_hub_arn_from_name( + hub_name=hub_resource_arn_extracted_info.hub_name, + region=hub_resource_arn_extracted_info.region, + account_id=hub_resource_arn_extracted_info.account_id, + ) + + model_version: str = hub_utils.get_hub_model_version( + hub_model_name=hub_resource_arn_extracted_info.hub_content_name, + hub_model_type=data_type.value, + hub_name=hub_arn, + sagemaker_session=self._sagemaker_session, + hub_model_version=hub_resource_arn_extracted_info.hub_content_version, + ) + + hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=hub_arn, + hub_content_name=hub_resource_arn_extracted_info.hub_content_name, + hub_content_version=model_version, + hub_content_type=data_type.value, + ) + + model_specs = make_model_specs_from_describe_hub_content_response( + DescribeHubContentResponse(hub_model_description), + ) + + return JumpStartCachedContentValue(formatted_content=model_specs) + + raise ValueError(self._file_type_error_msg(data_type)) def get_manifest( self, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models 
manifest.""" - manifest_dict = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest_dict = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -479,9 +549,7 @@ def _select_version( """ if version_str == "*": - if len(available_versions) == 0: - return None - return str(max(available_versions)) + return utils.get_latest_version(available_versions) if model_type == JumpStartModelType.PROPRIETARY: if "*" in version_str: @@ -492,6 +560,12 @@ def _select_version( ) return version_str if version_str in available_versions else None + if version_str[-1] == "*": + # major or minor version is pinned, e.g 1.* or 1.0.* + return utils.get_latest_version( + [version for version in available_versions if version.startswith(version_str[:-1])] + ) + try: spec = SpecifierSet(f"=={version_str}") except InvalidSpecifier: @@ -525,8 +599,8 @@ def _get_header_impl( JumpStartVersionedModelId(model_id, semantic_version_str) )[0] - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -556,8 +630,8 @@ def get_specs( """ header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key - specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) + specs, cache_hit = self._content_cache.get( + JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) ) if not cache_hit and "*" in version_str: @@ -566,8 +640,38 @@ def get_specs( ) return specs.formatted_content + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model + + Args: + hub_model_arn (str): Arn for the Hub model to get specs for + """ + + details, _ = self._content_cache.get( + JumpStartCachedContentKey( + HubContentType.MODEL, + hub_model_arn, + ) + ) + return details.formatted_content + + def get_hub_model_reference(self, hub_model_reference_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model reference + + Args: + hub_model_arn (str): Arn for the Hub model to get specs for + """ + + details, _ = self._content_cache.get( + JumpStartCachedContentKey( + HubContentType.MODEL_REFERENCE, + hub_model_reference_arn, + ) + ) + return details.formatted_content + def clear(self) -> None: """Clears the model ID/version and s3 cache.""" - self._s3_cache.clear() + self._content_cache.clear() self._open_weight_model_id_manifest_key_cache.clear() self._proprietary_model_id_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 2c3c6040d9..b81f97ce3a 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -15,6 +15,7 @@ import logging import os from typing import Dict, Set, Type +import json import boto3 from sagemaker.base_deserializers import BaseDeserializer, JSONDeserializer from sagemaker.jumpstart.enums import ( @@ -35,135 +36,58 @@ from sagemaker.session import Session +JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") + +# disable logging if env var is set +JUMPSTART_LOGGER.addHandler( + type( + "", + (logging.StreamHandler,), + { + "emit": lambda self, *args, **kwargs: ( + logging.StreamHandler.emit(self, *args, **kwargs) + if not 
os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING) + else None + ) + }, + )() +) + + +_CURRENT_FILE_DIRECTORY_PATH = os.path.dirname(os.path.realpath(__file__)) +REGION_CONFIG_JSON_FILENAME = "region_config.json" +REGION_CONFIG_JSON_FILEPATH = os.path.join( + _CURRENT_FILE_DIRECTORY_PATH, REGION_CONFIG_JSON_FILENAME +) + + +def _load_region_config(filepath: str) -> Set[JumpStartLaunchedRegionInfo]: + """Load the JumpStart region config from a JSON file.""" + debug_msg = f"Loading JumpStart region config from '{filepath}'." + JUMPSTART_LOGGER.debug(debug_msg) + try: + with open(filepath) as f: + config = json.load(f) + + return { + JumpStartLaunchedRegionInfo( + region_name=region, + content_bucket=data["content_bucket"], + gated_content_bucket=data.get("gated_content_bucket"), + neo_content_bucket=data.get("neo_content_bucket"), + ) + for region, data in config.items() + } + except Exception: # pylint: disable=W0703 + JUMPSTART_LOGGER.error("Unable to load JumpStart region config.", exc_info=True) + return set() + + ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING = "DISABLE_JUMPSTART_LOGGING" +ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY = "DISABLE_JUMPSTART_TELEMETRY" -JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set( - [ - JumpStartLaunchedRegionInfo( - region_name="us-west-2", - content_bucket="jumpstart-cache-prod-us-west-2", - gated_content_bucket="jumpstart-private-cache-prod-us-west-2", - ), - JumpStartLaunchedRegionInfo( - region_name="us-east-1", - content_bucket="jumpstart-cache-prod-us-east-1", - gated_content_bucket="jumpstart-private-cache-prod-us-east-1", - ), - JumpStartLaunchedRegionInfo( - region_name="us-east-2", - content_bucket="jumpstart-cache-prod-us-east-2", - gated_content_bucket="jumpstart-private-cache-prod-us-east-2", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-west-1", - content_bucket="jumpstart-cache-prod-eu-west-1", - gated_content_bucket="jumpstart-private-cache-prod-eu-west-1", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-central-1", - content_bucket="jumpstart-cache-prod-eu-central-1", - gated_content_bucket="jumpstart-private-cache-prod-eu-central-1", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-north-1", - content_bucket="jumpstart-cache-prod-eu-north-1", - gated_content_bucket="jumpstart-private-cache-prod-eu-north-1", - ), - JumpStartLaunchedRegionInfo( - region_name="me-south-1", - content_bucket="jumpstart-cache-prod-me-south-1", - gated_content_bucket="jumpstart-private-cache-prod-me-south-1", - ), - JumpStartLaunchedRegionInfo( - region_name="me-central-1", - content_bucket="jumpstart-cache-prod-me-central-1", - gated_content_bucket="jumpstart-private-cache-prod-me-central-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-south-1", - content_bucket="jumpstart-cache-prod-ap-south-1", - gated_content_bucket="jumpstart-private-cache-prod-ap-south-1", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-west-3", - content_bucket="jumpstart-cache-prod-eu-west-3", - gated_content_bucket="jumpstart-private-cache-prod-eu-west-3", - ), - JumpStartLaunchedRegionInfo( - region_name="af-south-1", - content_bucket="jumpstart-cache-prod-af-south-1", - gated_content_bucket="jumpstart-private-cache-prod-af-south-1", - ), - JumpStartLaunchedRegionInfo( - region_name="sa-east-1", - content_bucket="jumpstart-cache-prod-sa-east-1", - gated_content_bucket="jumpstart-private-cache-prod-sa-east-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-east-1", - content_bucket="jumpstart-cache-prod-ap-east-1", - 
gated_content_bucket="jumpstart-private-cache-prod-ap-east-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-northeast-2", - content_bucket="jumpstart-cache-prod-ap-northeast-2", - gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-2", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-northeast-3", - content_bucket="jumpstart-cache-prod-ap-northeast-3", - gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-3", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-southeast-3", - content_bucket="jumpstart-cache-prod-ap-southeast-3", - gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-3", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-west-2", - content_bucket="jumpstart-cache-prod-eu-west-2", - gated_content_bucket="jumpstart-private-cache-prod-eu-west-2", - ), - JumpStartLaunchedRegionInfo( - region_name="eu-south-1", - content_bucket="jumpstart-cache-prod-eu-south-1", - gated_content_bucket="jumpstart-private-cache-prod-eu-south-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-northeast-1", - content_bucket="jumpstart-cache-prod-ap-northeast-1", - gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-1", - ), - JumpStartLaunchedRegionInfo( - region_name="us-west-1", - content_bucket="jumpstart-cache-prod-us-west-1", - gated_content_bucket="jumpstart-private-cache-prod-us-west-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-southeast-1", - content_bucket="jumpstart-cache-prod-ap-southeast-1", - gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-1", - ), - JumpStartLaunchedRegionInfo( - region_name="ap-southeast-2", - content_bucket="jumpstart-cache-prod-ap-southeast-2", - gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-2", - ), - JumpStartLaunchedRegionInfo( - region_name="ca-central-1", - content_bucket="jumpstart-cache-prod-ca-central-1", - gated_content_bucket="jumpstart-private-cache-prod-ca-central-1", - ), - JumpStartLaunchedRegionInfo( - region_name="cn-north-1", - content_bucket="jumpstart-cache-prod-cn-north-1", - ), - JumpStartLaunchedRegionInfo( - region_name="il-central-1", - content_bucket="jumpstart-cache-prod-il-central-1", - gated_content_bucket="jumpstart-private-cache-prod-il-central-1", - ), - ] +JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = _load_region_config( + REGION_CONFIG_JSON_FILEPATH ) JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = { @@ -183,10 +107,16 @@ ) JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" +NEO_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" + +JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub" JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" +HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" +HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" + INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" @@ -200,6 +130,7 @@ "AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE" ) ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE" +ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE = "AWS_NEO_CONTENT_BUCKET_OVERRIDE" JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart" @@ -245,23 +176,6 @@ MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html" -JUMPSTART_LOGGER = 
logging.getLogger("sagemaker.jumpstart") - -# disable logging if env var is set -JUMPSTART_LOGGER.addHandler( - type( - "", - (logging.StreamHandler,), - { - "emit": lambda self, *args, **kwargs: ( - logging.StreamHandler.emit(self, *args, **kwargs) - if not os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING) - else None - ) - }, - )() -) - try: DEFAULT_JUMPSTART_SAGEMAKER_SESSION = Session( boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME) diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index ca49fd41a3..91f547afb6 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -15,6 +15,7 @@ from __future__ import absolute_import from enum import Enum +from typing import List class ModelFramework(str, Enum): @@ -81,6 +82,12 @@ class VariableTypes(str, Enum): BOOL = "bool" +class HubContentCapability(str, Enum): + """Enum class for HubContent capabilities.""" + + BEDROCK_CONSOLE = "BEDROCK_CONSOLE" + + class JumpStartTag(str, Enum): """Enum class for tag keys to apply to JumpStart models.""" @@ -93,6 +100,13 @@ class JumpStartTag(str, Enum): MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" + INFERENCE_CONFIG_NAME = "sagemaker-sdk:jumpstart-inference-config-name" + TRAINING_CONFIG_NAME = "sagemaker-sdk:jumpstart-training-config-name" + + HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" + + BEDROCK = "sagemaker-sdk:bedrock" + class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" @@ -126,6 +140,28 @@ def from_suffixed_type(mime_type_with_suffix: str) -> "MIMEType": return MIMEType(base_type) +class NamingConventionType(str, Enum): + """Enum class for naming conventions.""" + + SNAKE_CASE = "snake_case" + UPPER_CAMEL_CASE = "upper_camel_case" + DEFAULT = UPPER_CAMEL_CASE + + +class ModelSpecKwargType(str, Enum): + """Enum class for types of kwargs for model hub content document and model specs.""" + + FIT = "fit_kwargs" + MODEL = "model_kwargs" + ESTIMATOR = "estimator_kwargs" + DEPLOY = "deploy_kwargs" + + @classmethod + def arg_keys(cls) -> List[str]: + """Returns a list of kwargs keys that each type can have""" + return [member.value for member in cls] + + class JumpStartConfigRankingName(str, Enum): """Enum class for ranking of JumpStart config.""" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index bade834cc6..e61e1c49a5 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -14,7 +14,7 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from sagemaker import session from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -28,16 +28,22 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG from sagemaker.jumpstart.factory.estimator import get_deploy_kwargs, get_fit_kwargs, get_init_kwargs from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import 
get_model_id_version_from_training_job +from sagemaker.jumpstart.session_utils import get_model_info_from_training_job +from sagemaker.jumpstart.types import JumpStartMetadataConfig from sagemaker.jumpstart.utils import ( + get_jumpstart_configs, validate_model_id_and_get_type, resolve_model_sagemaker_config_field, verify_model_region_and_return_specs, + remove_env_var_from_estimator_kwargs_if_model_access_config_present, + get_model_access_config, + get_hub_access_config, ) from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model_monitor.data_capture_config import DataCaptureConfig @@ -58,6 +64,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -109,7 +116,10 @@ def __init__( container_arguments: Optional[List[str]] = None, disable_output_compression: Optional[bool] = None, enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None, + config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, + instance_placement_config: Optional[Dict] = None, ): """Initializes a ``JumpStartEstimator``. @@ -124,6 +134,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies @@ -343,8 +354,8 @@ def __init__( source_dir (Optional[Union[str, PipelineVariable]]): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file. If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + is preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. (Default: None). @@ -501,13 +512,37 @@ def __init__( to Amazon S3 without compression after training finishes. enable_remote_debug (bool or PipelineVariable): Optional. Specifies whether RemoteDebug is enabled for the training job + config_name (Optional[str]): + Name of the training configuration to apply to the Estimator. (Default: None). enable_session_tag_chaining (bool or PipelineVariable): Optional. Specifies whether SessionTagChaining is enabled for the training job + training_plan (str or PipelineVariable): Optional. + Specifies which training plan arn to use for the training job + instance_placement_config (dict): Optional. + Specifies UltraServer placement configuration for the training job + + .. code:: python + + instance_placement_config={ + "EnableMultipleJobs": True, + "PlacementSpecifications":[ + { + "UltraServerId": "ultraserver-1", + "InstanceCount": "2" + } + ] + } Raises: ValueError: If the model ID is not recognized by JumpStart. 
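Taken together, the new constructor parameters are exercised like this; a hedged sketch with hypothetical IDs and names:

from sagemaker.jumpstart.estimator import JumpStartEstimator

estimator = JumpStartEstimator(
    model_id="huggingface-text2text-flan-t5-base",  # hypothetical
    hub_name="MyCompanyHub",     # hypothetical; resolved to a hub ARN internally
    config_name="gpu-training",  # hypothetical training config
)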
""" + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + def _validate_model_id_and_get_type_hook(): return validate_model_id_and_get_type( model_id=model_id, @@ -515,18 +550,20 @@ def _validate_model_id_and_get_type_hook(): region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, + hub_arn=hub_arn, ) self.model_type = _validate_model_id_and_get_type_hook() if not self.model_type: JumpStartModelsAccessor.reset_cache() self.model_type = _validate_model_id_and_get_type_hook() - if not self.model_type: + if not self.model_type and not hub_arn: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=self.model_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -581,9 +618,13 @@ def _validate_model_id_and_get_type_hook(): disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, + training_plan=training_plan, + instance_placement_config=instance_placement_config, ) + self.hub_arn = estimator_init_kwargs.hub_arn self.model_id = estimator_init_kwargs.model_id self.model_version = estimator_init_kwargs.model_version self.instance_type = estimator_init_kwargs.instance_type @@ -591,10 +632,17 @@ def _validate_model_id_and_get_type_hook(): self.tolerate_vulnerable_model = estimator_init_kwargs.tolerate_vulnerable_model self.instance_count = estimator_init_kwargs.instance_count self.region = estimator_init_kwargs.region + self.environment = estimator_init_kwargs.environment self.orig_predictor_cls = None self.role = estimator_init_kwargs.role self.sagemaker_session = estimator_init_kwargs.sagemaker_session self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation + self.config_name = estimator_init_kwargs.config_name + self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False) + # Access configs initialized to None, would be given a value when .fit() is called + # if applicable + self.model_access_config = None + self.hub_access_config = None super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict()) @@ -605,6 +653,7 @@ def fit( logs: Optional[str] = None, job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, + accept_eula: Optional[bool] = None, ) -> None: """Start training job by calling base ``Estimator`` class ``fit`` method. @@ -655,11 +704,20 @@ def fit( is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. However, the value of `TrialComponentDisplayName` is honored for display in Studio. (Default: None). + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). 
""" - + self.model_access_config = get_model_access_config(accept_eula, self.environment) + self.hub_access_config = get_hub_access_config( + hub_content_arn=self.init_kwargs.get("model_reference_arn", None) + ) estimator_fit_kwargs = get_fit_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, inputs=inputs, wait=wait, @@ -669,6 +727,11 @@ def fit( tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, sagemaker_session=self.sagemaker_session, + config_name=self.config_name, + hub_access_config=self.hub_access_config, + ) + remove_env_var_from_estimator_kwargs_if_model_access_config_present( + self.init_kwargs, self.model_access_config ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -679,8 +742,10 @@ def attach( training_job_name: str, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_channel_name: str = "model", + config_name: Optional[str] = None, ) -> "JumpStartEstimator": """Attach to an existing training job. @@ -716,6 +781,8 @@ def attach( model data will be downloaded (default: 'model'). If no channel with the same name exists in the training job, this option will be ignored. + config_name (str): Optional. Name of the training configuration to use + when attaching to the training job. (Default: None). Returns: Instance of the calling ``JumpStartEstimator`` Class with the attached @@ -725,25 +792,34 @@ def attach( ValueError: if the model ID or version cannot be inferred from the training job. """ - + config_name = None if model_id is None: - - model_id, model_version = get_model_id_version_from_training_job( + model_id, model_version, _, config_name = get_model_info_from_training_job( training_job_name=training_job_name, sagemaker_session=sagemaker_session ) model_version = model_version or "*" - additional_kwargs = {"model_id": model_id, "model_version": model_version} + additional_kwargs = { + "model_id": model_id, + "model_version": model_version, + "tolerate_vulnerable_model": True, # model is already trained + "tolerate_deprecated_model": True, # model is already trained + } + + if config_name: + additional_kwargs.update({"config_name": config_name}) model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=sagemaker_session.boto_region_name, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable sagemaker_session=sagemaker_session, + config_name=config_name, ) # eula was already accepted if the model was successfully trained @@ -778,7 +854,7 @@ def deploy( explainer_config: Optional[ExplainerConfig] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, model_name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, @@ -793,6 +869,7 @@ def deploy( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, use_compiled_model: bool = False, + inference_config_name: Optional[str] = None, ) -> PredictorBase: """Creates 
endpoint from training job. @@ -878,7 +955,7 @@ def deploy( It can be null if this is being used to create a Model to pass to a ``PipelineModel`` which has its own Role field. (Default: None). - predictor_cls (Optional[callable[string, sagemaker.session.Session]]): A + predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A function to call to create a predictor (Default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. (Default: None). @@ -907,8 +984,8 @@ def deploy( source_dir (Optional[str]): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (Default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory is + preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. If the directory points to S3, no code is uploaded and the S3 location is used instead. (Default: None). @@ -1028,6 +1105,8 @@ def deploy( (Default: None). use_compiled_model (bool): Flag to select whether to use compiled (optimized) model. (Default: False). + inference_config_name (Optional[str]): Name of the inference configuration to + be used in the model. (Default: None). """ self.orig_predictor_cls = predictor_cls @@ -1042,6 +1121,7 @@ def deploy( estimator_deploy_kwargs = get_deploy_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, @@ -1080,6 +1160,8 @@ def deploy( git_config=git_config, use_compiled_model=use_compiled_model, training_instance_type=self.instance_type, + training_config_name=self.config_name, + inference_config_name=inference_config_name, ) predictor = super(JumpStartEstimator, self).deploy( @@ -1092,15 +1174,48 @@ def deploy( predictor=predictor, model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, + config_name=estimator_deploy_kwargs.config_name, ) # If a predictor class was passed, do not mutate predictor return predictor + def list_training_configs(self) -> List[JumpStartMetadataConfig]: + """Returns a list of configs associated with the estimator. + + Raises: + ValueError: If the combination of arguments specified is not supported. + VulnerableJumpStartModelError: If any of the dependencies required by the script have + known security vulnerabilities. + DeprecatedJumpStartModelError: If the version of the model is deprecated. + """ + configs_dict = get_jumpstart_configs( + model_id=self.model_id, + model_version=self.model_version, + model_type=self.model_type, + region=self.region, + scope=JumpStartScriptScope.TRAINING, + sagemaker_session=self.sagemaker_session, + ) + return list(configs_dict.values()) + + def set_training_config(self, config_name: str) -> None: + """Sets the config to apply to the model. + + Args: + config_name (str): The name of the config. 
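A sketch of the two new config APIs together (config name hypothetical; the config_name attribute is assumed from JumpStartMetadataConfig):

for config in estimator.list_training_configs():
    print(config.config_name)

# Re-initializes the estimator in place with the chosen config, per the
# implementation that follows.
estimator.set_training_config(config_name="gpu-training")  # hypothetical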
+ """ + self.__init__( + model_id=self.model_id, + model_version=self.model_version, + config_name=config_name, + ) + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 742a6b8d3f..13994c2ed9 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -32,8 +32,8 @@ ) INVALID_MODEL_ID_ERROR_MSG = ( - "Invalid model ID: '{model_id}'. Please visit " - f"{MODEL_ID_LIST_WEB_URL} for a list of valid model IDs. " + "Invalid model ID: '{model_id}'. Specify a different model ID or try a different AWS Region. " + f"For a list of available models, see {MODEL_ID_LIST_WEB_URL}. " "The module `sagemaker.jumpstart.notebook_utils` contains utilities for " "fetching model IDs. We recommend upgrading to the latest version of sagemaker " "to get access to the most models." diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 387a4a843c..81e1356050 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -14,7 +14,7 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from sagemaker import ( environment_variables, hyperparameters as hyperparameters_utils, @@ -29,6 +29,14 @@ _retrieve_model_package_model_artifact_s3_uri, ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base +from sagemaker.jumpstart.factory.utils import ( + _set_temp_sagemaker_session_if_not_set, + get_model_info_default_kwargs, +) +from sagemaker.jumpstart.hub.utils import ( + construct_hub_model_arn_from_inputs, + construct_hub_model_reference_arn_from_inputs, +) from sagemaker.session import Session from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -44,15 +52,16 @@ _model_supports_training_model_uri, ) from sagemaker.jumpstart.constants import ( - DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, + JUMPSTART_MODEL_HUB_NAME, TRAINING_ENTRY_POINT_SCRIPT_NAME, SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, ) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model from sagemaker.jumpstart.types import ( + HubContentType, JumpStartEstimatorDeployKwargs, JumpStartEstimatorFitKwargs, JumpStartEstimatorInitKwargs, @@ -61,8 +70,10 @@ JumpStartModelInitKwargs, ) from sagemaker.jumpstart.utils import ( - add_jumpstart_model_id_version_tags, - get_eula_message, + add_hub_content_arn_tags, + add_jumpstart_model_info_tags, + get_default_jumpstart_session_with_user_agent_suffix, + get_top_ranked_config_name, update_dict_if_key_not_present, resolve_estimator_sagemaker_config_field, verify_model_region_and_return_specs, @@ -78,6 +89,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, @@ -130,13 +142,17 @@ def get_init_kwargs( disable_output_compression: Optional[bool] = None, enable_infra_check: Optional[Union[bool, PipelineVariable]] = None, enable_remote_debug: Optional[Union[bool, 
PipelineVariable]] = None, + config_name: Optional[str] = None, enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None, + training_plan: Optional[Union[str, PipelineVariable]] = None, + instance_placement_config: Optional[Dict] = None, ) -> JumpStartEstimatorInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object.""" estimator_init_kwargs: JumpStartEstimatorInitKwargs = JumpStartEstimatorInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, role=role, region=region, @@ -189,15 +205,40 @@ def get_init_kwargs( disable_output_compression=disable_output_compression, enable_infra_check=enable_infra_check, enable_remote_debug=enable_remote_debug, + config_name=config_name, enable_session_tag_chaining=enable_session_tag_chaining, + training_plan=training_plan, + instance_placement_config=instance_placement_config, + ) + + estimator_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set( + kwargs=estimator_init_kwargs + ) + estimator_init_kwargs.specs = verify_model_region_and_return_specs( + **get_model_info_default_kwargs( + estimator_init_kwargs, include_model_version=False, include_tolerate_flags=False + ), + version=estimator_init_kwargs.model_version or "*", + scope=JumpStartScriptScope.TRAINING, + # We set these flags to True to retrieve the json specs. + # Exceptions will be thrown later if these are not tolerated. + tolerate_deprecated_model=True, + tolerate_vulnerable_model=True, ) estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs) - estimator_init_kwargs = _add_sagemaker_session_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( + estimator_init_kwargs, orig_session + ) estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs) + if hub_arn: + estimator_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=estimator_init_kwargs) + else: + estimator_init_kwargs.model_reference_arn = None + estimator_init_kwargs.hub_content_type = None estimator_init_kwargs = _add_model_uri_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_source_dir_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_entry_point_to_kwargs(estimator_init_kwargs) @@ -207,6 +248,7 @@ def get_init_kwargs( estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_tags_to_kwargs(estimator_init_kwargs) + estimator_init_kwargs = _add_config_name_to_kwargs(estimator_init_kwargs) return estimator_init_kwargs @@ -214,6 +256,7 @@ def get_init_kwargs( def get_fit_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, wait: Optional[bool] = None, @@ -223,12 +266,15 @@ def get_fit_kwargs( tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + config_name: Optional[str] = None, + hub_access_config: Optional[Dict] = None, ) -> JumpStartEstimatorFitKwargs: """Returns kwargs required call `fit` on 
`sagemaker.estimator.Estimator` object."""

    estimator_fit_kwargs: JumpStartEstimatorFitKwargs = JumpStartEstimatorFitKwargs(
        model_id=model_id,
        model_version=model_version,
+        hub_arn=hub_arn,
        region=region,
        inputs=inputs,
        wait=wait,
@@ -238,19 +284,71 @@
        tolerate_deprecated_model=tolerate_deprecated_model,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        sagemaker_session=sagemaker_session,
+        config_name=config_name,
+    )
+
+    estimator_fit_kwargs, _ = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_fit_kwargs)
+    estimator_fit_kwargs.specs = verify_model_region_and_return_specs(
+        **get_model_info_default_kwargs(
+            estimator_fit_kwargs, include_model_version=False, include_tolerate_flags=False
+        ),
+        version=estimator_fit_kwargs.model_version or "*",
+        scope=JumpStartScriptScope.TRAINING,
+        # We set these flags to True to retrieve the json specs.
+        # Exceptions will be thrown later if these are not tolerated.
+        tolerate_deprecated_model=True,
+        tolerate_vulnerable_model=True,
    )

    estimator_fit_kwargs = _add_model_version_to_kwargs(estimator_fit_kwargs)
    estimator_fit_kwargs = _add_region_to_kwargs(estimator_fit_kwargs)
    estimator_fit_kwargs = _add_training_job_name_to_kwargs(estimator_fit_kwargs)
    estimator_fit_kwargs = _add_fit_extra_kwargs(estimator_fit_kwargs)
+    estimator_fit_kwargs = _add_hub_access_config_to_kwargs_inputs(
+        estimator_fit_kwargs, hub_access_config
+    )

    return estimator_fit_kwargs


+def _add_hub_access_config_to_kwargs_inputs(
+    kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None
+):
+    """Adds HubAccessConfig to kwargs inputs."""
+
+    dataset_uri = kwargs.specs.default_training_dataset_uri
+    if isinstance(kwargs.inputs, str):
+        if dataset_uri is not None and dataset_uri == kwargs.inputs:
+            kwargs.inputs = TrainingInput(
+                s3_data=kwargs.inputs, hub_access_config=hub_access_config
+            )
+    elif isinstance(kwargs.inputs, TrainingInput):
+        if (
+            dataset_uri is not None
+            and dataset_uri == kwargs.inputs.config["DataSource"]["S3DataSource"]["S3Uri"]
+        ):
+            kwargs.inputs.add_hub_access_config(hub_access_config=hub_access_config)
+    elif isinstance(kwargs.inputs, dict):
+        for k, v in kwargs.inputs.items():
+            if isinstance(v, str):
+                training_input = TrainingInput(s3_data=v)
+                if dataset_uri is not None and dataset_uri == v:
+                    training_input.add_hub_access_config(hub_access_config=hub_access_config)
+                kwargs.inputs[k] = training_input
+            elif isinstance(v, TrainingInput):
+                if (
+                    dataset_uri is not None
+                    and dataset_uri == v.config["DataSource"]["S3DataSource"]["S3Uri"]
+                ):
+                    v.add_hub_access_config(hub_access_config=hub_access_config)
+
+    return kwargs
+
+
 def get_deploy_kwargs(
    model_id: str,
    model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
    region: Optional[str] = None,
    initial_instance_count: Optional[int] = None,
    instance_type: Optional[str] = None,
@@ -258,6 +356,7 @@ def get_deploy_kwargs(
    deserializer: Optional[BaseDeserializer] = None,
    accelerator_type: Optional[str] = None,
    endpoint_name: Optional[str] = None,
+    inference_component_name: Optional[str] = None,
    tags: Optional[Tags] = None,
    kms_key: Optional[str] = None,
    wait: Optional[bool] = None,
@@ -271,7 +370,7 @@ def get_deploy_kwargs(
    explainer_config: Optional[ExplainerConfig] = None,
    image_uri: Optional[Union[str, PipelineVariable]] = None,
    role: Optional[str] = None,
-    predictor_cls: Optional[callable] = None,
+    predictor_cls: Optional[Callable] = None,
    env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, sagemaker_session: Optional[Session] = None, @@ -289,12 +388,15 @@ def get_deploy_kwargs( use_compiled_model: Optional[bool] = None, model_name: Optional[str] = None, training_instance_type: Optional[str] = None, + training_config_name: Optional[str] = None, + inference_config_name: Optional[str] = None, ) -> JumpStartEstimatorDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object.""" model_deploy_kwargs: JumpStartModelDeployKwargs = model.get_deploy_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -302,6 +404,7 @@ def get_deploy_kwargs( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, + inference_component_name=inference_component_name, tags=format_tags(tags), kms_key=kms_key, wait=wait, @@ -316,13 +419,21 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + training_config_name=training_config_name, + config_name=inference_config_name, ) model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs( model_id=model_id, model_from_estimator=True, model_version=model_version, - instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None, + hub_arn=hub_arn, + instance_type=( + model_deploy_kwargs.instance_type + if training_instance_type is None + or instance_type is not None # always use supplied inference instance type + else None + ), region=region, image_uri=image_uri, source_dir=source_dir, @@ -344,11 +455,13 @@ def get_deploy_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, training_instance_type=training_instance_type, disable_instance_type_logging=True, + config_name=model_deploy_kwargs.config_name, ) estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( model_id=model_init_kwargs.model_id, model_version=model_init_kwargs.model_version, + hub_arn=hub_arn, instance_type=model_init_kwargs.instance_type, initial_instance_count=model_deploy_kwargs.initial_instance_count, region=model_init_kwargs.region, @@ -388,6 +501,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=model_init_kwargs.tolerate_vulnerable_model, tolerate_deprecated_model=model_init_kwargs.tolerate_deprecated_model, use_compiled_model=use_compiled_model, + config_name=model_deploy_kwargs.config_name, ) return estimator_deploy_kwargs @@ -401,9 +515,16 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: return kwargs -def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: +def _add_sagemaker_session_with_custom_user_agent_to_kwargs( + kwargs: JumpStartKwargs, orig_session: Optional[Session] +) -> JumpStartKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" - kwargs.sagemaker_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION + kwargs.sagemaker_session = orig_session or get_default_jumpstart_session_with_user_agent_suffix( + model_id=kwargs.model_id, + model_version=kwargs.model_version, + config_name=None, + is_hub_content=kwargs.hub_arn is not None, + ) return kwargs @@ -412,6 +533,10 @@ def _add_model_version_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: kwargs.model_version = kwargs.model_version or "*" + if kwargs.hub_arn: + 
hub_content_version = kwargs.specs.version + kwargs.model_version = hub_content_version + return kwargs @@ -436,13 +561,7 @@ def _add_instance_type_and_count_to_kwargs( orig_instance_type = kwargs.instance_type kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - scope=JumpStartScriptScope.TRAINING, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + **get_model_info_default_kwargs(kwargs), scope=JumpStartScriptScope.TRAINING ) kwargs.instance_count = kwargs.instance_count or 1 @@ -458,20 +577,28 @@ def _add_instance_type_and_count_to_kwargs( def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs: """Sets tags in kwargs based on default or override, returns full kwargs.""" - full_model_version = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - scope=JumpStartScriptScope.TRAINING, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - ).version + full_model_version = kwargs.specs.version if kwargs.sagemaker_session.settings.include_jumpstart_tags: - kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version + kwargs.tags = add_jumpstart_model_info_tags( + kwargs.tags, + kwargs.model_id, + full_model_version, + config_name=kwargs.config_name, + scope=JumpStartScriptScope.TRAINING, ) + + if kwargs.hub_arn: + if kwargs.model_reference_arn: + hub_content_arn = construct_hub_model_reference_arn_from_inputs( + kwargs.hub_arn, kwargs.model_id, kwargs.model_version + ) + else: + hub_content_arn = construct_hub_model_arn_from_inputs( + kwargs.hub_arn, kwargs.model_id, kwargs.model_version + ) + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) + return kwargs @@ -479,53 +606,51 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE """Sets image uri in kwargs based on default or override, returns full kwargs.""" kwargs.image_uri = kwargs.image_uri or image_uris.retrieve( - region=kwargs.region, + **get_model_info_default_kwargs(kwargs), + instance_type=kwargs.instance_type, framework=None, image_scope=JumpStartScriptScope.TRAINING, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - instance_type=kwargs.instance_type, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, ) return kwargs +def _add_model_reference_arn_to_kwargs( + kwargs: JumpStartEstimatorInitKwargs, +) -> JumpStartEstimatorInitKwargs: + """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" + + hub_content_type = kwargs.specs.hub_content_type + kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None + + if hub_content_type == HubContentType.MODEL_REFERENCE: + kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version + ) + else: + kwargs.model_reference_arn = None + return kwargs + + def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs: 
"""Sets model uri in kwargs based on default or override, returns full kwargs.""" - - if _model_supports_training_model_uri( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + # hub_arn is by default None unless the user specifies the hub_name + # If no hub_name is specified, it is assumed the public hub + # Training platform enforces that private hub models must use model channel + is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False + if is_private_hub or _model_supports_training_model_uri( + **get_model_info_default_kwargs(kwargs) ): default_model_uri = model_uris.retrieve( model_scope=JumpStartScriptScope.TRAINING, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - region=kwargs.region, instance_type=kwargs.instance_type, + **get_model_info_default_kwargs(kwargs), ) if ( kwargs.model_uri is not None and kwargs.model_uri != default_model_uri - and not _model_supports_incremental_training( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - ) + and not _model_supports_incremental_training(**get_model_info_default_kwargs(kwargs)) ): JUMPSTART_LOGGER.warning( "'%s' does not support incremental training but is being trained with" @@ -553,13 +678,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart """Sets source dir in kwargs based on default or override, returns full kwargs.""" kwargs.source_dir = kwargs.source_dir or script_uris.retrieve( - script_scope=JumpStartScriptScope.TRAINING, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - region=kwargs.region, - sagemaker_session=kwargs.sagemaker_session, + script_scope=JumpStartScriptScope.TRAINING, **get_model_info_default_kwargs(kwargs) ) return kwargs @@ -571,25 +690,15 @@ def _add_env_to_kwargs( """Sets environment in kwargs based on default or override, returns full kwargs.""" extra_env_vars = environment_variables.retrieve_default( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - include_aws_sdk_env_vars=False, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + **get_model_info_default_kwargs(kwargs), script=JumpStartScriptScope.TRAINING, instance_type=kwargs.instance_type, + include_aws_sdk_env_vars=False, ) model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, + **get_model_info_default_kwargs(kwargs), scope=JumpStartScriptScope.TRAINING, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, ) if model_package_artifact_uri: @@ -604,26 +713,6 @@ 
def _add_env_to_kwargs( value, ) - environment = getattr(kwargs, "environment", {}) or {} - if ( - environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY) - and str(environment.get("accept_eula", "")).lower() != "true" - ): - model_specs = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - region=kwargs.region, - scope=JumpStartScriptScope.TRAINING, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - ) - if model_specs.is_gated_model(): - raise ValueError( - "Need to define ‘accept_eula'='true' within Environment. " - f"{get_eula_message(model_specs, kwargs.region)}" - ) - return kwargs @@ -643,12 +732,8 @@ def _add_training_job_name_to_kwargs( """Sets resource name based on default or override, returns full kwargs.""" default_training_job_name = _retrieve_resource_name_base( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + **get_model_info_default_kwargs(kwargs), + scope=JumpStartScriptScope.TRAINING, ) kwargs.job_name = kwargs.job_name or ( @@ -668,12 +753,7 @@ def _add_hyperparameters_to_kwargs( ) default_hyperparameters = hyperparameters_utils.retrieve_default( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type, ) @@ -701,12 +781,7 @@ def _add_metric_definitions_to_kwargs( default_metric_definitions = ( metric_definitions_utils.retrieve_default( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type, ) or [] @@ -730,13 +805,7 @@ def _add_estimator_extra_kwargs( """Sets extra kwargs based on default or override, returns full kwargs.""" estimator_kwargs_to_add = _retrieve_estimator_init_kwargs( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - instance_type=kwargs.instance_type, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type ) for key, value in estimator_kwargs_to_add.items(): @@ -754,17 +823,23 @@ def _add_estimator_extra_kwargs( def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstimatorFitKwargs: """Sets extra kwargs based on default or override, returns full kwargs.""" - fit_kwargs_to_add = _retrieve_estimator_fit_kwargs( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - ) + fit_kwargs_to_add = _retrieve_estimator_fit_kwargs(**get_model_info_default_kwargs(kwargs)) for 
key, value in fit_kwargs_to_add.items():
        if getattr(kwargs, key) is None:
            setattr(kwargs, key, value)

    return kwargs
+
+
+def _add_config_name_to_kwargs(
+    kwargs: JumpStartEstimatorInitKwargs,
+) -> JumpStartEstimatorInitKwargs:
+    """Sets config name in kwargs based on default or override, returns full kwargs."""
+
+    kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
+        scope=JumpStartScriptScope.TRAINING,
+        **get_model_info_default_kwargs(kwargs, include_config_name=False),
+    )
+
+    return kwargs
diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py
index 28746990e3..53ded3f275 100644
--- a/src/sagemaker/jumpstart/factory/model.py
+++ b/src/sagemaker/jumpstart/factory/model.py
@@ -15,7 +15,8 @@
 import json

-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
+from sagemaker_core.shapes import ModelAccessConfig
 from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris
 from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
 from sagemaker.base_deserializers import BaseDeserializer
@@ -29,49 +30,72 @@
 )
 from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
 from sagemaker.jumpstart.constants import (
-    DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
     INFERENCE_ENTRY_POINT_SCRIPT_NAME,
     JUMPSTART_DEFAULT_REGION_NAME,
     JUMPSTART_LOGGER,
 )
+from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
+from sagemaker.jumpstart.hub.utils import (
+    construct_hub_model_arn_from_inputs,
+    construct_hub_model_reference_arn_from_inputs,
+)
 from sagemaker.model_metrics import ModelMetrics
 from sagemaker.metadata_properties import MetadataProperties
 from sagemaker.drift_check_baselines import DriftCheckBaselines
-from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
+from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType, HubContentCapability
 from sagemaker.jumpstart.types import (
+    HubContentType,
     JumpStartModelDeployKwargs,
     JumpStartModelInitKwargs,
     JumpStartModelRegisterKwargs,
+    JumpStartModelSpecs,
 )
 from sagemaker.jumpstart.utils import (
-    add_jumpstart_model_id_version_tags,
+    add_hub_content_arn_tags,
+    add_jumpstart_model_info_tags,
+    add_bedrock_store_tags,
+    get_default_jumpstart_session_with_user_agent_suffix,
+    get_top_ranked_config_name,
     update_dict_if_key_not_present,
     resolve_model_sagemaker_config_field,
     verify_model_region_and_return_specs,
+    get_draft_model_content_bucket,
 )
+from sagemaker.jumpstart.factory.utils import (
+    _set_temp_sagemaker_session_if_not_set,
+    get_model_info_default_kwargs,
+)
 from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
 from sagemaker.base_predictor import Predictor
 from sagemaker import accept_types, content_types, serializers, deserializers
 from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
 from sagemaker.session import Session
-from sagemaker.utils import name_from_base, format_tags, Tags
+from sagemaker.utils import (
+    camel_case_to_pascal_case,
+    name_from_base,
+    format_tags,
+    Tags,
+)
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
 from sagemaker import resource_requirements
 from sagemaker.enums import EndpointType
+from sagemaker.model_life_cycle import ModelLifeCycle


 def get_default_predictor(
     predictor: Predictor,
model_id: str, model_version: str, + hub_arn: Optional[str], region: str, tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, sagemaker_session: Session, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -80,7 +104,7 @@ def get_default_predictor( """ # if there's a non-default predictor, do not mutate -- return as is - if type(predictor) != Predictor: # pylint: disable=C0123 + if not isinstance(predictor, Predictor): raise RuntimeError( "Can only get default predictor from base Predictor class. " f"Using Predictor class '{type(predictor).__name__}'." @@ -89,38 +113,46 @@ def get_default_predictor( predictor.serializer = serializers.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) return predictor @@ -136,11 +168,19 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni return kwargs -def _add_sagemaker_session_to_kwargs( - kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs] +def _add_sagemaker_session_with_custom_user_agent_to_kwargs( + kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs], + orig_session: Optional[Session], ) -> JumpStartModelInitKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" - kwargs.sagemaker_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION + + kwargs.sagemaker_session = orig_session or get_default_jumpstart_session_with_user_agent_suffix( + model_id=kwargs.model_id, + model_version=kwargs.model_version, + config_name=kwargs.config_name, + is_hub_content=kwargs.hub_arn is not None, + ) + return kwargs @@ -164,6 +204,10 @@ def _add_model_version_to_kwargs( kwargs.model_version = kwargs.model_version or "*" + if kwargs.hub_arn: + hub_content_version = kwargs.specs.version + kwargs.model_version = hub_content_version + return kwargs @@ -184,17 +228,10 @@ def _add_instance_type_to_kwargs( """Sets instance type based on default or override, returns full kwargs.""" orig_instance_type = kwargs.instance_type - kwargs.instance_type = kwargs.instance_type or instance_types.retrieve_default( - region=kwargs.region, - model_id=kwargs.model_id, - 
model_version=kwargs.model_version, + **get_model_info_default_kwargs(kwargs), scope=JumpStartScriptScope.INFERENCE, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, training_instance_type=kwargs.training_instance_type, - model_type=kwargs.model_type, ) if not disable_instance_type_logging and orig_instance_type is None: @@ -203,6 +240,22 @@ def _add_instance_type_to_kwargs( kwargs.instance_type, ) + specs = kwargs.specs + + if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs: + return kwargs + + resolved_config = ( + specs.inference_configs.configs[kwargs.config_name].resolved_config + if specs.inference_configs + else None + ) + if resolved_config is None: + return kwargs + supported_instance_types = resolved_config.get("supported_inference_instance_types", []) + if kwargs.instance_type not in supported_instance_types: + JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) + return kwargs @@ -217,20 +270,32 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel return kwargs kwargs.image_uri = kwargs.image_uri or image_uris.retrieve( - region=kwargs.region, + **get_model_info_default_kwargs(kwargs), framework=None, image_scope=JumpStartScriptScope.INFERENCE, - model_id=kwargs.model_id, - model_version=kwargs.model_version, instance_type=kwargs.instance_type, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, ) return kwargs +def _add_model_reference_arn_to_kwargs( + kwargs: JumpStartModelInitKwargs, +) -> JumpStartModelInitKwargs: + """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" + + hub_content_type = kwargs.specs.hub_content_type + kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None + + if hub_content_type == HubContentType.MODEL_REFERENCE: + kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version + ) + else: + kwargs.model_reference_arn = None + return kwargs + + def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets model data based on default or override, returns full kwargs.""" @@ -238,14 +303,10 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode kwargs.model_data = None return kwargs + model_info_kwargs = get_model_info_default_kwargs(kwargs) model_data: Union[str, dict] = kwargs.model_data or model_uris.retrieve( + **model_info_kwargs, model_scope=JumpStartScriptScope.INFERENCE, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, instance_type=kwargs.instance_type, ) @@ -280,22 +341,9 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode source_dir = kwargs.source_dir - if _model_supports_inference_script_uri( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - ): + if 
_model_supports_inference_script_uri(**get_model_info_default_kwargs(kwargs)): source_dir = source_dir or script_uris.retrieve( - script_scope=JumpStartScriptScope.INFERENCE, - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, + **get_model_info_default_kwargs(kwargs), script_scope=JumpStartScriptScope.INFERENCE ) kwargs.source_dir = source_dir @@ -312,14 +360,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod entry_point = kwargs.entry_point - if _model_supports_inference_script_uri( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - ): + if _model_supports_inference_script_uri(**get_model_info_default_kwargs(kwargs)): entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME @@ -341,13 +382,8 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw env = {} extra_env_vars = environment_variables.retrieve_default( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, + **get_model_info_default_kwargs(kwargs), include_aws_sdk_env_vars=False, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, script=JumpStartScriptScope.INFERENCE, instance_type=kwargs.instance_type, ) @@ -371,15 +407,9 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt """Sets model package arn based on default or override, returns full kwargs.""" model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn( - model_id=kwargs.model_id, - model_version=kwargs.model_version, + **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type, scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, ) kwargs.model_package_arn = model_package_arn @@ -389,15 +419,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets extra kwargs based on default or override, returns full kwargs.""" - model_kwargs_to_add = _retrieve_model_init_kwargs( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - ) + model_kwargs_to_add = _retrieve_model_init_kwargs(**get_model_info_default_kwargs(kwargs)) for key, value in model_kwargs_to_add.items(): if getattr(kwargs, key) is None: @@ -425,15 +447,7 @@ def _add_endpoint_name_to_kwargs( ) -> JumpStartModelDeployKwargs: """Sets resource name based on default or override, returns full kwargs.""" - default_endpoint_name = _retrieve_resource_name_base( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - 
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - ) + default_endpoint_name = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs)) kwargs.endpoint_name = kwargs.endpoint_name or ( name_from_base(default_endpoint_name) if default_endpoint_name is not None else None @@ -447,15 +461,7 @@ def _add_model_name_to_kwargs( ) -> JumpStartModelInitKwargs: """Sets resource name based on default or override, returns full kwargs.""" - default_model_name = _retrieve_resource_name_base( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - ) + default_model_name = _retrieve_resource_name_base(**get_model_info_default_kwargs(kwargs)) kwargs.name = kwargs.name or ( name_from_base(default_model_name) if default_model_name is not None else None @@ -467,22 +473,33 @@ def _add_model_name_to_kwargs( def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: """Sets tags based on default or override, returns full kwargs.""" - full_model_version = verify_model_region_and_return_specs( - model_id=kwargs.model_id, - version=kwargs.model_version, - scope=JumpStartScriptScope.INFERENCE, - region=kwargs.region, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, - ).version + full_model_version = kwargs.specs.version if kwargs.sagemaker_session.settings.include_jumpstart_tags: - kwargs.tags = add_jumpstart_model_id_version_tags( - kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type + kwargs.tags = add_jumpstart_model_info_tags( + kwargs.tags, + kwargs.model_id, + full_model_version, + kwargs.model_type, + config_name=kwargs.config_name, + scope=JumpStartScriptScope.INFERENCE, ) + if kwargs.hub_arn: + if kwargs.model_reference_arn: + hub_content_arn = construct_hub_model_reference_arn_from_inputs( + kwargs.hub_arn, kwargs.model_id, kwargs.model_version + ) + else: + hub_content_arn = construct_hub_model_arn_from_inputs( + kwargs.hub_arn, kwargs.model_id, kwargs.model_version + ) + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) + + if hasattr(kwargs.specs, "capabilities") and kwargs.specs.capabilities is not None: + if HubContentCapability.BEDROCK_CONSOLE in kwargs.specs.capabilities: + kwargs.tags = add_bedrock_store_tags(kwargs.tags, compatibility="compatible") + return kwargs @@ -490,14 +507,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] """Sets extra kwargs based on default or override, returns full kwargs.""" deploy_kwargs_to_add = _retrieve_model_deploy_kwargs( - model_id=kwargs.model_id, - model_version=kwargs.model_version, - instance_type=kwargs.instance_type, - region=kwargs.region, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, + **get_model_info_default_kwargs(kwargs), instance_type=kwargs.instance_type ) for key, value in deploy_kwargs_to_add.items(): @@ -511,23 +521,114 @@ def 
_add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel """Sets the resource requirements based on the default or an override. Returns full kwargs.""" kwargs.resources = kwargs.resources or resource_requirements.retrieve_default( - region=kwargs.region, - model_id=kwargs.model_id, - model_version=kwargs.model_version, + **get_model_info_default_kwargs(kwargs), scope=JumpStartScriptScope.INFERENCE, - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, - sagemaker_session=kwargs.sagemaker_session, - model_type=kwargs.model_type, instance_type=kwargs.instance_type, ) return kwargs +def _select_inference_config_from_training_config( + specs: JumpStartModelSpecs, training_config_name: str +) -> Optional[str]: + """Selects the inference config from the training config. + + Args: + specs (JumpStartModelSpecs): The specs for the model. + training_config_name (str): The name of the training config. + + Returns: + str: The name of the inference config. + """ + if specs.training_configs: + resolved_training_config = specs.training_configs.configs.get(training_config_name) + if resolved_training_config: + return resolved_training_config.default_inference_config + + return None + + +def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: + """Sets default config name to the kwargs. Returns full kwargs. + + Raises: + ValueError: If the instance_type is not supported with the current config. + """ + + kwargs.config_name = kwargs.config_name or get_top_ranked_config_name( + **get_model_info_default_kwargs(kwargs, include_config_name=False), + scope=JumpStartScriptScope.INFERENCE, + ) + + if kwargs.config_name is None: + return kwargs + + return kwargs + + +def _add_additional_model_data_sources_to_kwargs( + kwargs: JumpStartModelInitKwargs, +) -> JumpStartModelInitKwargs: + """Sets default additional model data sources to init kwargs""" + + specs = kwargs.specs + # Append speculative decoding data source from metadata + speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources() + for data_source in speculative_decoding_data_sources: + data_source.s3_data_source.set_bucket( + get_draft_model_content_bucket(provider=data_source.provider, region=kwargs.region) + ) + api_shape_additional_model_data_sources = ( + [ + camel_case_to_pascal_case(data_source.to_json()) + for data_source in speculative_decoding_data_sources + ] + if specs.get_speculative_decoding_s3_data_sources() + else None + ) + + kwargs.additional_model_data_sources = ( + kwargs.additional_model_data_sources or api_shape_additional_model_data_sources + ) + + return kwargs + + +def _add_config_name_to_deploy_kwargs( + kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None +) -> JumpStartModelInitKwargs: + """Sets default config name to the kwargs. Returns full kwargs. + + If a training_config_name is passed, then choose the inference config + based on the supported inference configs in that training config. + + Raises: + ValueError: If the instance_type is not supported with the current config. 
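+
+        Resolution order, roughly: an explicit ``config_name`` in the kwargs wins;
+        otherwise a supplied ``training_config_name`` is mapped to that training
+        config's default inference config; otherwise the top-ranked inference
+        config for the model is used.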
+ """ + + if training_config_name: + + specs = kwargs.specs + default_config_name = _select_inference_config_from_training_config( + specs=specs, training_config_name=training_config_name + ) + + else: + default_config_name = kwargs.config_name or get_top_ranked_config_name( + **get_model_info_default_kwargs(kwargs, include_config_name=False), + scope=JumpStartScriptScope.INFERENCE, + ) + + kwargs.config_name = kwargs.config_name or default_config_name + + return kwargs + + def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, @@ -536,6 +637,7 @@ def get_deploy_kwargs( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, + inference_component_name: Optional[str] = None, tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, @@ -551,16 +653,23 @@ def get_deploy_kwargs( tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: Optional[EndpointType] = None, + training_config_name: Optional[str] = None, + config_name: Optional[str] = None, + routing_config: Optional[Dict[str, Any]] = None, + model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None, + inference_ami_version: Optional[str] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" deploy_kwargs: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, region=region, initial_instance_count=initial_instance_count, @@ -569,6 +678,7 @@ def get_deploy_kwargs( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, + inference_component_name=inference_component_name, tags=format_tags(tags), kms_key=kms_key, wait=wait, @@ -584,14 +694,37 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, endpoint_logging=endpoint_logging, resources=resources, + config_name=config_name, + routing_config=routing_config, + model_access_configs=model_access_configs, + inference_ami_version=inference_ami_version, + ) + deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs) + deploy_kwargs.specs = verify_model_region_and_return_specs( + **get_model_info_default_kwargs( + deploy_kwargs, include_model_version=False, include_tolerate_flags=False + ), + version=deploy_kwargs.model_version or "*", + scope=JumpStartScriptScope.INFERENCE, + # We set these flags to True to retrieve the json specs. + # Exceptions will be thrown later if these are not tolerated. 
+ tolerate_deprecated_model=True, + tolerate_vulnerable_model=True, ) - deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_config_name_to_deploy_kwargs( + kwargs=deploy_kwargs, training_config_name=training_config_name + ) deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( + kwargs=deploy_kwargs, orig_session=orig_session + ) + deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs) @@ -613,6 +746,8 @@ def get_deploy_kwargs( def get_register_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, + model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, @@ -639,12 +774,19 @@ def get_register_kwargs( data_input_configuration: Optional[str] = None, skip_model_validation: Optional[str] = None, source_uri: Optional[str] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, + config_name: Optional[str] = None, + model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None, + accept_eula: Optional[bool] = None, ) -> JumpStartModelRegisterKwargs: """Returns kwargs required to call `register` on `sagemaker.estimator.Model` object.""" register_kwargs = JumpStartModelRegisterKwargs( model_id=model_id, model_version=model_version, + config_name=config_name, + hub_arn=hub_arn, + model_type=model_type, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -671,23 +813,30 @@ def get_register_kwargs( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_life_cycle=model_life_cycle, + model_card=model_card, + accept_eula=accept_eula, ) - model_specs = verify_model_region_and_return_specs( - model_id=model_id, - version=model_version, - region=region, + register_kwargs.specs = verify_model_region_and_return_specs( + **get_model_info_default_kwargs( + register_kwargs, include_model_version=False, include_tolerate_flags=False + ), + version=register_kwargs.model_version or "*", scope=JumpStartScriptScope.INFERENCE, - sagemaker_session=sagemaker_session, - tolerate_deprecated_model=tolerate_deprecated_model, - tolerate_vulnerable_model=tolerate_vulnerable_model, + # We set these flags to True to retrieve the json specs. + # Exceptions will be thrown later if these are not tolerated. 
+ tolerate_deprecated_model=True, + tolerate_vulnerable_model=True, ) register_kwargs.content_types = ( - register_kwargs.content_types or model_specs.predictor_specs.supported_content_types + register_kwargs.content_types + or register_kwargs.specs.predictor_specs.supported_content_types ) register_kwargs.response_types = ( - register_kwargs.response_types or model_specs.predictor_specs.supported_accept_types + register_kwargs.response_types + or register_kwargs.specs.predictor_specs.supported_accept_types ) return register_kwargs @@ -697,6 +846,7 @@ def get_init_kwargs( model_id: str, model_from_estimator: bool = False, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, @@ -705,7 +855,7 @@ def get_init_kwargs( image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, @@ -723,12 +873,15 @@ def get_init_kwargs( training_instance_type: Optional[str] = None, disable_instance_type_logging: bool = False, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" model_init_kwargs: JumpStartModelInitKwargs = JumpStartModelInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, instance_type=instance_type, region=region, @@ -754,13 +907,31 @@ def get_init_kwargs( model_package_arn=model_package_arn, training_instance_type=training_instance_type, resources=resources, + config_name=config_name, + additional_model_data_sources=additional_model_data_sources, + ) + model_init_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set( + kwargs=model_init_kwargs + ) + model_init_kwargs.specs = verify_model_region_and_return_specs( + **get_model_info_default_kwargs( + model_init_kwargs, include_model_version=False, include_tolerate_flags=False + ), + version=model_init_kwargs.model_version or "*", + scope=JumpStartScriptScope.INFERENCE, + # We set these flags to True to retrieve the json specs. + # Exceptions will be thrown later if these are not tolerated. 
+ tolerate_deprecated_model=True, + tolerate_vulnerable_model=True, ) - - model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( + kwargs=model_init_kwargs, orig_session=orig_session + ) model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) @@ -771,6 +942,12 @@ def get_init_kwargs( model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs) + if hub_arn: + model_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=model_init_kwargs) + else: + model_init_kwargs.model_reference_arn = None + model_init_kwargs.hub_content_type = None + # we use the model artifact from the training job output if not model_from_estimator: model_init_kwargs = _add_model_data_to_kwargs(kwargs=model_init_kwargs) @@ -784,4 +961,6 @@ def get_init_kwargs( model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs) + return model_init_kwargs diff --git a/src/sagemaker/jumpstart/factory/utils.py b/src/sagemaker/jumpstart/factory/utils.py new file mode 100644 index 0000000000..faf1f8886f --- /dev/null +++ b/src/sagemaker/jumpstart/factory/utils.py @@ -0,0 +1,79 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module stores JumpStart factory utilities.""" + +from __future__ import absolute_import +from typing import Tuple, Union + +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.types import ( + JumpStartEstimatorDeployKwargs, + JumpStartEstimatorFitKwargs, + JumpStartEstimatorInitKwargs, + JumpStartModelDeployKwargs, + JumpStartModelInitKwargs, +) +from sagemaker.session import Session + +KwargsType = Union[ + JumpStartModelDeployKwargs, + JumpStartModelInitKwargs, + JumpStartEstimatorFitKwargs, + JumpStartEstimatorInitKwargs, + JumpStartEstimatorDeployKwargs, +] + + +def get_model_info_default_kwargs( + kwargs: KwargsType, + include_config_name: bool = True, + include_model_version: bool = True, + include_tolerate_flags: bool = True, +) -> dict: + """Returns a dictionary of model info kwargs to use with JumpStart APIs.""" + + kwargs_dict = { + "model_id": kwargs.model_id, + "hub_arn": kwargs.hub_arn, + "region": kwargs.region, + "sagemaker_session": kwargs.sagemaker_session, + "model_type": kwargs.model_type, + } + if include_config_name: + kwargs_dict.update({"config_name": kwargs.config_name}) + + if include_model_version: + kwargs_dict.update({"model_version": kwargs.model_version}) + + if include_tolerate_flags: + kwargs_dict.update( + { + "tolerate_deprecated_model": kwargs.tolerate_deprecated_model, + "tolerate_vulnerable_model": kwargs.tolerate_vulnerable_model, + } + ) + + return kwargs_dict + + +def _set_temp_sagemaker_session_if_not_set(kwargs: KwargsType) -> Tuple[KwargsType, Session]: + """Sets a temporary sagemaker session if one is not set, and returns original session. + + We need to create a default JS session (without custom user agent) + in order to retrieve config name info. + """ + + orig_session = kwargs.sagemaker_session + if kwargs.sagemaker_session is None: + kwargs.sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION + return kwargs, orig_session diff --git a/src/sagemaker/jumpstart/hub/__init__.py b/src/sagemaker/jumpstart/hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/jumpstart/hub/constants.py b/src/sagemaker/jumpstart/hub/constants.py new file mode 100644 index 0000000000..e3a6b7752a --- /dev/null +++ b/src/sagemaker/jumpstart/hub/constants.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores constants related to SageMaker JumpStart Hub.""" +from __future__ import absolute_import + +LATEST_VERSION_WILDCARD = "*" diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py new file mode 100644 index 0000000000..692966cee4 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -0,0 +1,291 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+# pylint: skip-file
+"""This module provides the JumpStart Hub class."""
+from __future__ import absolute_import
+from datetime import datetime
+import logging
+from typing import Optional, Dict, List, Any, Union
+
+from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME
+from sagemaker.jumpstart.enums import JumpStartScriptScope
+from sagemaker.session import Session
+
+from sagemaker.jumpstart.types import (
+    HubContentType,
+)
+from sagemaker.jumpstart.filters import Constant, Operator, BooleanValues
+from sagemaker.jumpstart.hub.utils import (
+    get_hub_model_version,
+    get_info_from_hub_resource_arn,
+    construct_hub_arn_from_name,
+)
+
+from sagemaker.jumpstart.notebook_utils import (
+    list_jumpstart_models,
+)
+
+from sagemaker.jumpstart.hub.interfaces import (
+    DescribeHubResponse,
+    DescribeHubContentResponse,
+)
+from sagemaker.jumpstart.hub.constants import (
+    LATEST_VERSION_WILDCARD,
+)
+from sagemaker.jumpstart import utils
+
+
+class Hub:
+    """Class for creating and managing a curated JumpStart hub."""
+
+    # Setting LOGGER for backward compatibility, in case users import it...
+    logger = LOGGER = logging.getLogger("sagemaker")
+
+    _list_hubs_cache: List[Dict[str, Any]] = []
+
+    def __init__(
+        self,
+        hub_name: str,
+        sagemaker_session: Session,
+        bucket_name: Optional[str] = None,
+    ) -> None:
+        """Instantiates a SageMaker ``Hub``.
+
+        Args:
+            hub_name (str): The name of the Hub to create.
+            sagemaker_session (sagemaker.session.Session): A SageMaker Session
+                object, used for SageMaker interactions.
+            bucket_name (Optional[str]): The name of the S3 bucket used for the
+                hub's storage configuration, if any.
+        """
+        self.hub_name = hub_name
+        self.region = sagemaker_session.boto_region_name
+        self.bucket_name = bucket_name
+        self._sagemaker_session = (
+            sagemaker_session
+            or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True)
+        )
+
+    def _get_latest_model_version(self, model_id: str) -> str:
+        """Returns the latest version of a model from specs, regardless of the version passed.
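+
+        The ``*`` wildcard is resolved against the region's model specs, so the
+        result always reflects the most recently published version.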
+ + Returns model ({ model_id: str, version: str }) + """ + model_specs = utils.verify_model_region_and_return_specs( + model_id, LATEST_VERSION_WILDCARD, JumpStartScriptScope.INFERENCE, self.region + ) + return model_specs.version + + def create( + self, + description: str, + display_name: Optional[str] = None, + search_keywords: Optional[str] = None, + tags: Optional[str] = None, + ) -> Dict[str, str]: + """Creates a hub with the given description""" + curr_timestamp = datetime.now().timestamp() + + request = { + "hub_name": self.hub_name, + "hub_description": description, + "hub_display_name": display_name, + "hub_search_keywords": search_keywords, + "tags": tags, + } + + if self.bucket_name: + request["s3_storage_config"] = { + "S3OutputPath": (f"s3://{self.bucket_name}/{self.hub_name}-{curr_timestamp}") + } + + return self._sagemaker_session.create_hub(**request) + + def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: + """Returns descriptive information about the Hub""" + + hub_description: DescribeHubResponse = self._sagemaker_session.describe_hub( + hub_name=self.hub_name if not hub_name else hub_name + ) + + return hub_description + + def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]]: + """List and paginate models from Hub.""" + next_token: Optional[str] = None + first_iteration: bool = True + hub_model_summaries: List[Dict[str, Any]] = [] + + while first_iteration or next_token: + first_iteration = False + list_hub_content_response = self._sagemaker_session.list_hub_contents(**kwargs) + hub_model_summaries.extend(list_hub_content_response.get("HubContentSummaries", [])) + next_token = list_hub_content_response.get("NextToken") + + return hub_model_summaries + + def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]: + """Lists the models and model references in this SageMaker Hub. + + This function caches the models in local memory + + **kwargs: Passed to invocation of ``Session:list_hub_contents``. + """ + response = {} + + if clear_cache: + self._list_hubs_cache = None + if self._list_hubs_cache is None: + + hub_model_reference_summaries = self._list_and_paginate_models( + **{ + "hub_name": self.hub_name, + "hub_content_type": HubContentType.MODEL_REFERENCE.value, + **kwargs, + } + ) + + hub_model_summaries = self._list_and_paginate_models( + **{ + "hub_name": self.hub_name, + "hub_content_type": HubContentType.MODEL.value, + **kwargs, + } + ) + response["hub_content_summaries"] = hub_model_reference_summaries + hub_model_summaries + response["next_token"] = None # Temporary until pagination is implemented + return response + + def list_sagemaker_public_hub_models( + self, + filter: Union[Operator, str] = Constant(BooleanValues.TRUE), + next_token: Optional[str] = None, + ) -> Dict[str, Any]: + """Lists the models and model arns from AmazonSageMakerJumpStart Public Hub. + + Args: + filter (Union[Operator, str]): Optional. The filter to apply to list models. This can be + either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``), + or simply a string filter which will get serialized into an Identity filter. + (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. + (Default: Constant(BooleanValues.TRUE)). + next_token (str): Optional. A token to resume pagination of list_inference_components. + This is currently not implemented. 
+ """ + + response = {} + + jumpstart_public_hub_arn = construct_hub_arn_from_name( + JUMPSTART_MODEL_HUB_NAME, self.region, self._sagemaker_session + ) + + hub_content_summaries = [] + models = list_jumpstart_models(filter=filter, list_versions=True) + for model in models: + if len(model) <= 63: + info = get_info_from_hub_resource_arn(jumpstart_public_hub_arn) + hub_model_arn = ( + f"arn:{info.partition}:" + f"sagemaker:{info.region}:" + f"aws:hub-content/{info.hub_name}/" + f"{HubContentType.MODEL.value}/{model[0]}" + ) + hub_content_summary = { + "hub_content_name": model[0], + "hub_content_arn": hub_model_arn, + } + hub_content_summaries.append(hub_content_summary) + response["hub_content_summaries"] = hub_content_summaries + + response["next_token"] = None # Temporary until pagination is implemented for this function + + return response + + def delete(self) -> None: + """Deletes this SageMaker Hub.""" + return self._sagemaker_session.delete_hub(self.hub_name) + + def create_model_reference( + self, model_arn: str, model_name: Optional[str] = None, min_version: Optional[str] = None + ): + """Adds model reference to this SageMaker Hub.""" + return self._sagemaker_session.create_hub_content_reference( + hub_name=self.hub_name, + source_hub_content_arn=model_arn, + hub_content_name=model_name, + min_version=min_version, + ) + + def delete_model_reference(self, model_name: str) -> None: + """Deletes model reference from this SageMaker Hub.""" + return self._sagemaker_session.delete_hub_content_reference( + hub_name=self.hub_name, + hub_content_type=HubContentType.MODEL_REFERENCE.value, + hub_content_name=model_name, + ) + + def describe_model( + self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None + ) -> DescribeHubContentResponse: + """Describe Model or ModelReference in a Hub.""" + hub_name = hub_name or self.hub_name + + # Users only input model id, not contentType, so first try to describe with ModelReference, then with Model + try: + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL_REFERENCE.value, + hub_name=hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=model_version, + ) + + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL_REFERENCE.value, + ) + + except Exception as ex: + logging.info( + "Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: " + + str(ex) + ) + + # Failed to describe ModelReference, try with Model + try: + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL.value, + hub_name=hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=model_version, + ) + + hub_content_description: Dict[str, Any] = ( + self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL.value, + ) + ) + + except Exception as ex: + # Failed with both, throw a custom error message + raise RuntimeError( + f"Cannot get details for {model_name} in Hub {hub_name}. 
\ + {model_name} does not exist as a Model or ModelReference in {hub_name}: \n" + + str(ex) + ) + + return DescribeHubContentResponse(hub_content_description) diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py new file mode 100644 index 0000000000..6ba5a37c3c --- /dev/null +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -0,0 +1,936 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores types related to SageMaker JumpStart HubAPI requests and responses.""" +from __future__ import absolute_import + +from enum import Enum +import re +import json +import datetime + +from typing import Any, Dict, List, Union, Optional +from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.jumpstart.types import ( + HubContentType, + HubArnExtractedInfo, + JumpStartConfigComponent, + JumpStartConfigRanking, + JumpStartMetadataConfig, + JumpStartMetadataConfigs, + JumpStartPredictorSpecs, + JumpStartHyperparameter, + JumpStartDataHolderType, + JumpStartEnvironmentVariable, + JumpStartSerializablePayload, + JumpStartInstanceTypeVariants, +) +from sagemaker.jumpstart.hub.parser_utils import ( + snake_to_upper_camel, + walk_and_apply_json, +) + + +class _ComponentType(str, Enum): + """Enum for different component types.""" + + INFERENCE = "Inference" + TRAINING = "Training" + + +class HubDataHolderType(JumpStartDataHolderType): + """Base class for many Hub API interfaces.""" + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of object.""" + json_obj = {} + for att in self.__slots__: + if att in self._non_serializable_slots: + continue + if hasattr(self, att): + cur_val = getattr(self, att) + # Do not serialize null values. + if cur_val is None: + continue + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + elif isinstance(cur_val, datetime.datetime): + json_obj[att] = str(cur_val) + else: + json_obj[att] = cur_val + return json_obj + + def __str__(self) -> str: + """Returns string representation of object. + + Example: "{'content_bucket': 'bucket', 'region_name': 'us-west-2'}" + """ + + att_dict = walk_and_apply_json(self.to_json(), snake_to_upper_camel) + return f"{json.dumps(att_dict, default=lambda o: o.to_json())}" + + +class CreateHubResponse(HubDataHolderType): + """Data class for the Hub from session.create_hub()""" + + __slots__ = [ + "hub_arn", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates CreateHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of session.create_hub() response. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. 
+ + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.hub_arn: str = json_obj["HubArn"] + + +class HubContentDependency(HubDataHolderType): + """Data class for any dependencies related to hub content. + + Content can be scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["dependency_copy_path", "dependency_origin_path", "dependency_type"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentDependency object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.dependency_copy_path: Optional[str] = json_obj.get("DependencyCopyPath", "") + self.dependency_origin_path: Optional[str] = json_obj.get("DependencyOriginPath", "") + self.dependency_type: Optional[str] = json_obj.get("DependencyType", "") + + +class DescribeHubContentResponse(HubDataHolderType): + """Data class for the Hub Content from session.describe_hub_contents()""" + + __slots__ = [ + "creation_time", + "document_schema_version", + "failure_reason", + "hub_arn", + "hub_content_arn", + "hub_content_dependencies", + "hub_content_description", + "hub_content_display_name", + "hub_content_document", + "hub_content_markdown", + "hub_content_name", + "hub_content_search_keywords", + "hub_content_status", + "hub_content_type", + "hub_content_version", + "reference_min_version", + "hub_name", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubContentResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. 
+ """ + self.creation_time: datetime.datetime = json_obj["CreationTime"] + self.document_schema_version: str = json_obj["DocumentSchemaVersion"] + self.failure_reason: Optional[str] = json_obj.get("FailureReason") + self.hub_arn: str = json_obj["HubArn"] + self.hub_content_arn: str = json_obj["HubContentArn"] + self.hub_content_dependencies = [] + if "Dependencies" in json_obj: + self.hub_content_dependencies: Optional[List[HubContentDependency]] = [ + HubContentDependency(dep) for dep in json_obj.get(["Dependencies"]) + ] + self.hub_content_description: str = json_obj.get("HubContentDescription") + self.hub_content_display_name: str = json_obj.get("HubContentDisplayName") + hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn) + self._region = hub_region + self.hub_content_type: str = json_obj.get("HubContentType") + hub_content_document = json.loads(json_obj["HubContentDocument"]) + if self.hub_content_type == HubContentType.MODEL: + self.hub_content_document: HubContentDocument = HubModelDocument( + json_obj=hub_content_document, + region=self._region, + dependencies=self.hub_content_dependencies, + ) + elif self.hub_content_type == HubContentType.MODEL_REFERENCE: + self.hub_content_document: HubContentDocument = HubModelDocument( + json_obj=hub_content_document, + region=self._region, + dependencies=self.hub_content_dependencies, + ) + elif self.hub_content_type == HubContentType.NOTEBOOK: + self.hub_content_document: HubContentDocument = HubNotebookDocument( + json_obj=hub_content_document, region=self._region + ) + else: + raise ValueError( + f"[{self.hub_content_type}] is not a valid HubContentType." + f"Should be one of: {[item.name for item in HubContentType]}." + ) + + self.hub_content_markdown: str = json_obj.get("HubContentMarkdown") + self.hub_content_name: str = json_obj["HubContentName"] + self.hub_content_search_keywords: List[str] = json_obj.get("HubContentSearchKeywords") + self.hub_content_status: str = json_obj["HubContentStatus"] + self.hub_content_version: str = json_obj["HubContentVersion"] + self.hub_name: str = json_obj["HubName"] + + def get_hub_region(self) -> Optional[str]: + """Returns the region hub is in.""" + return self._region + + +class HubS3StorageConfig(HubDataHolderType): + """Data class for any dependencies related to hub content. + + Includes scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["s3_output_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubS3StorageConfig object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.s3_output_path: Optional[str] = json_obj.get("S3OutputPath", "") + + +class DescribeHubResponse(HubDataHolderType): + """Data class for the Hub from session.describe_hub()""" + + __slots__ = [ + "creation_time", + "failure_reason", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + "s3_storage_config", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. 
+ """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + + self.creation_time: datetime.datetime = datetime.datetime(json_obj["CreationTime"]) + self.failure_reason: str = json_obj["FailureReason"] + self.hub_arn: str = json_obj["HubArn"] + hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn) + self._region = hub_region + self.hub_description: str = json_obj["HubDescription"] + self.hub_display_name: str = json_obj["HubDisplayName"] + self.hub_name: str = json_obj["HubName"] + self.hub_search_keywords: List[str] = json_obj["HubSearchKeywords"] + self.hub_status: str = json_obj["HubStatus"] + self.last_modified_time: datetime.datetime = datetime.datetime(json_obj["LastModifiedTime"]) + self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig(json_obj["S3StorageConfig"]) + + def get_hub_region(self) -> Optional[str]: + """Returns the region hub is in.""" + return self._region + + +class ImportHubResponse(HubDataHolderType): + """Data class for the Hub from session.import_hub()""" + + __slots__ = [ + "hub_arn", + "hub_content_arn", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates ImportHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.hub_arn: str = json_obj["HubArn"] + self.hub_content_arn: str = json_obj["HubContentArn"] + + +class HubSummary(HubDataHolderType): + """Data class for the HubSummary from session.list_hubs()""" + + __slots__ = [ + "creation_time", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubSummary object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.creation_time: datetime.datetime = datetime.datetime(json_obj["CreationTime"]) + self.hub_arn: str = json_obj["HubArn"] + self.hub_description: str = json_obj["HubDescription"] + self.hub_display_name: str = json_obj["HubDisplayName"] + self.hub_name: str = json_obj["HubName"] + self.hub_search_keywords: List[str] = json_obj["HubSearchKeywords"] + self.hub_status: str = json_obj["HubStatus"] + self.last_modified_time: datetime.datetime = datetime.datetime(json_obj["LastModifiedTime"]) + + +class ListHubsResponse(HubDataHolderType): + """Data class for the Hub from session.list_hubs()""" + + __slots__ = [ + "hub_summaries", + "next_token", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates ListHubsResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response. 
+ """ + self.hub_summaries: List[HubSummary] = [ + HubSummary(item) for item in json_obj["HubSummaries"] + ] + self.next_token: str = json_obj["NextToken"] + + +class EcrUri(HubDataHolderType): + """Data class for ECR image uri.""" + + __slots__ = ["account", "region_name", "repository", "tag"] + + def __init__(self, uri: str): + """Instantiates EcrUri object.""" + self.from_ecr_uri(uri) + + def from_ecr_uri(self, uri: str) -> None: + """Parse a given aws ecr image uri into its various components.""" + uri_regex = ( + r"^(?:(?P[a-zA-Z0-9][\w-]*)\.dkr\.ecr\.(?P[a-zA-Z0-9][\w-]*)" + r"\.(?P[a-zA-Z0-9\.-]+))\/(?P([a-z0-9]+" + r"(?:[._-][a-z0-9]+)*\/)*[a-z0-9]+(?:[._-][a-z0-9]+)*)(:*)(?P.*)?" + ) + + parsed_image_uri = re.compile(uri_regex).match(uri) + + account = parsed_image_uri.group("account_id") + region = parsed_image_uri.group("region") + repository = parsed_image_uri.group("repository_name") + tag = parsed_image_uri.group("image_tag") + + self.account = account + self.region_name = region + self.repository = repository + self.tag = tag + + +class NotebookLocationUris(HubDataHolderType): + """Data class for Notebook Location uri.""" + + __slots__ = ["demo_notebook", "model_fit", "model_deploy"] + + def __init__(self, json_obj: Dict[str, Any]): + """Instantiates EcrUri object.""" + self.from_json(json_obj) + + def from_json(self, json_obj: str) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.demo_notebook = json_obj.get("demo_notebook") + self.model_fit = json_obj.get("model_fit") + self.model_deploy = json_obj.get("model_deploy") + + +class HubModelDocument(HubDataHolderType): + """Data class for model type HubContentDocument from session.describe_hub_content().""" + + SCHEMA_VERSION = "2.3.0" + + __slots__ = [ + "url", + "min_sdk_version", + "training_supported", + "model_types", + "capabilities", + "incremental_training_supported", + "dynamic_container_deployment_supported", + "hosting_ecr_uri", + "hosting_artifact_s3_data_type", + "hosting_artifact_compression_type", + "hosting_artifact_uri", + "hosting_prepacked_artifact_uri", + "hosting_prepacked_artifact_version", + "hosting_script_uri", + "hosting_use_script_uri", + "hosting_eula_uri", + "hosting_model_package_arn", + "inference_ami_version", + "model_subscription_link", + "inference_configs", + "inference_config_components", + "inference_config_rankings", + "training_artifact_s3_data_type", + "training_artifact_compression_type", + "training_model_package_artifact_uri", + "hyperparameters", + "inference_environment_variables", + "training_script_uri", + "training_prepacked_script_uri", + "training_prepacked_script_version", + "training_ecr_uri", + "training_metrics", + "training_artifact_uri", + "training_configs", + "training_config_components", + "training_config_rankings", + "inference_dependencies", + "training_dependencies", + "default_inference_instance_type", + "supported_inference_instance_types", + "default_training_instance_type", + "supported_training_instance_types", + "sage_maker_sdk_predictor_specifications", + "inference_volume_size", + "training_volume_size", + "inference_enable_network_isolation", + "training_enable_network_isolation", + "fine_tuning_supported", + "validation_supported", + "default_training_dataset_uri", + "resource_name_base", + "gated_bucket", + "default_payloads", + "hosting_resource_requirements", + "hosting_instance_type_variants", + "training_instance_type_variants", + 
"notebook_location_uris", + "model_provider_icon_uri", + "task", + "framework", + "datatype", + "license", + "contextual_help", + "model_data_download_timeout", + "container_startup_health_check_timeout", + "encrypt_inter_container_traffic", + "max_runtime_in_seconds", + "disable_output_compression", + "model_dir", + "dependencies", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__( + self, + json_obj: Dict[str, Any], + region: str, + dependencies: List[HubContentDependency] = None, + ) -> None: + """Instantiates HubModelDocument object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content document. + + Raises: + ValueError: When one of (json_obj) or (model_specs and studio_specs) is not provided. + """ + self._region = region + self.dependencies = dependencies or [] + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub model document. + """ + self.url: str = json_obj.get("Url") + self.min_sdk_version: str = json_obj.get("MinSdkVersion") + self.hosting_ecr_uri: Optional[str] = json_obj.get("HostingEcrUri") + self.hosting_artifact_uri = json_obj.get("HostingArtifactUri") + self.hosting_script_uri = json_obj.get("HostingScriptUri") + self.inference_dependencies: List[str] = json_obj.get("InferenceDependencies") + self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [ + JumpStartEnvironmentVariable(env_variable, is_hub_content=True) + for env_variable in json_obj.get("InferenceEnvironmentVariables", []) + ] + self.model_types: Optional[List[str]] = json_obj.get("ModelTypes") + self.capabilities: Optional[List[str]] = json_obj.get("Capabilities") + self.training_supported: bool = bool(json_obj.get("TrainingSupported")) + self.incremental_training_supported: bool = bool( + json_obj.get("IncrementalTrainingSupported") + ) + self.dynamic_container_deployment_supported: Optional[bool] = ( + bool(json_obj.get("DynamicContainerDeploymentSupported")) + if json_obj.get("DynamicContainerDeploymentSupported") + else None + ) + self.hosting_artifact_s3_data_type: Optional[str] = json_obj.get( + "HostingArtifactS3DataType" + ) + self.hosting_artifact_compression_type: Optional[str] = json_obj.get( + "HostingArtifactCompressionType" + ) + self.hosting_prepacked_artifact_uri: Optional[str] = json_obj.get( + "HostingPrepackedArtifactUri" + ) + self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get( + "HostingPrepackedArtifactVersion" + ) + self.hosting_use_script_uri: Optional[bool] = ( + bool(json_obj.get("HostingUseScriptUri")) + if json_obj.get("HostingUseScriptUri") is not None + else None + ) + self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") + self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + + self.inference_ami_version: Optional[str] = json_obj.get("InferenceAmiVersion") + + self.model_subscription_link: Optional[str] = json_obj.get("ModelSubscriptionLink") + + self.inference_config_rankings = self._get_config_rankings(json_obj) + self.inference_config_components = self._get_config_components(json_obj) + self.inference_configs = self._get_configs(json_obj) + + self.default_inference_instance_type: Optional[str] = json_obj.get( + "DefaultInferenceInstanceType" + ) + self.supported_inference_instance_types: Optional[str] = json_obj.get( + "SupportedInferenceInstanceTypes" + ) + 
self.sage_maker_sdk_predictor_specifications: Optional[JumpStartPredictorSpecs] = ( + JumpStartPredictorSpecs( + json_obj.get("SageMakerSdkPredictorSpecifications"), + is_hub_content=True, + ) + if json_obj.get("SageMakerSdkPredictorSpecifications") + else None + ) + self.inference_volume_size: Optional[int] = json_obj.get("InferenceVolumeSize") + self.inference_enable_network_isolation: Optional[str] = json_obj.get( + "InferenceEnableNetworkIsolation", False + ) + self.fine_tuning_supported: Optional[bool] = ( + bool(json_obj.get("FineTuningSupported")) + if json_obj.get("FineTuningSupported") + else None + ) + self.validation_supported: Optional[bool] = ( + bool(json_obj.get("ValidationSupported")) + if json_obj.get("ValidationSupported") + else None + ) + self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase") + self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False)) + self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( + { + alias: JumpStartSerializablePayload(payload, is_hub_content=True) + for alias, payload in json_obj.get("DefaultPayloads").items() + } + if json_obj.get("DefaultPayloads") + else None + ) + self.hosting_resource_requirements: Optional[Dict[str, int]] = json_obj.get( + "HostingResourceRequirements", None + ) + self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( + JumpStartInstanceTypeVariants( + json_obj.get("HostingInstanceTypeVariants"), + is_hub_content=True, + ) + if json_obj.get("HostingInstanceTypeVariants") + else None + ) + self.notebook_location_uris: Optional[NotebookLocationUris] = ( + NotebookLocationUris(json_obj.get("NotebookLocationUris")) + if json_obj.get("NotebookLocationUris") + else None + ) + self.model_provider_icon_uri: Optional[str] = None # Not needed for private beta + self.task: Optional[str] = json_obj.get("Task") + self.framework: Optional[str] = json_obj.get("Framework") + self.datatype: Optional[str] = json_obj.get("Datatype") + self.license: Optional[str] = json_obj.get("License") + self.contextual_help: Optional[str] = json_obj.get("ContextualHelp") + self.model_dir: Optional[str] = json_obj.get("ModelDir") + # Deploy kwargs + self.model_data_download_timeout: Optional[str] = json_obj.get("ModelDataDownloadTimeout") + self.container_startup_health_check_timeout: Optional[str] = json_obj.get( + "ContainerStartupHealthCheckTimeout" + ) + + if self.training_supported: + self.default_training_dataset_uri: Optional[str] = json_obj.get( + "DefaultTrainingDatasetUri" + ) + self.training_model_package_artifact_uri: Optional[str] = json_obj.get( + "TrainingModelPackageArtifactUri" + ) + self.training_artifact_compression_type: Optional[str] = json_obj.get( + "TrainingArtifactCompressionType" + ) + self.training_artifact_s3_data_type: Optional[str] = json_obj.get( + "TrainingArtifactS3DataType" + ) + self.hyperparameters: List[JumpStartHyperparameter] = [] + hyperparameters: Any = json_obj.get("Hyperparameters") + if hyperparameters is not None: + self.hyperparameters.extend( + [ + JumpStartHyperparameter(hyperparameter, is_hub_content=True) + for hyperparameter in hyperparameters + ] + ) + + self.training_script_uri: Optional[str] = json_obj.get("TrainingScriptUri") + self.training_prepacked_script_uri: Optional[str] = json_obj.get( + "TrainingPrepackedScriptUri" + ) + self.training_prepacked_script_version: Optional[str] = json_obj.get( + "TrainingPrepackedScriptVersion" + ) + self.training_ecr_uri: Optional[str] = json_obj.get("TrainingEcrUri") + 
self._non_serializable_slots.append("training_ecr_specs") + self.training_metrics: Optional[List[Dict[str, str]]] = json_obj.get( + "TrainingMetrics", None + ) + self.training_artifact_uri: Optional[str] = json_obj.get("TrainingArtifactUri") + + self.training_config_rankings = self._get_config_rankings( + json_obj, _ComponentType.TRAINING + ) + self.training_config_components = self._get_config_components( + json_obj, _ComponentType.TRAINING + ) + self.training_configs = self._get_configs(json_obj, _ComponentType.TRAINING) + + self.training_dependencies: Optional[str] = json_obj.get("TrainingDependencies") + self.default_training_instance_type: Optional[str] = json_obj.get( + "DefaultTrainingInstanceType" + ) + self.supported_training_instance_types: Optional[str] = json_obj.get( + "SupportedTrainingInstanceTypes" + ) + self.training_volume_size: Optional[int] = json_obj.get("TrainingVolumeSize") + self.training_enable_network_isolation: Optional[str] = json_obj.get( + "TrainingEnableNetworkIsolation", False + ) + self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( + JumpStartInstanceTypeVariants( + json_obj.get("TrainingInstanceTypeVariants"), + is_hub_content=True, + ) + if json_obj.get("TrainingInstanceTypeVariants") + else None + ) + # Estimator kwargs + self.encrypt_inter_container_traffic: Optional[bool] = ( + bool(json_obj.get("EncryptInterContainerTraffic")) + if json_obj.get("EncryptInterContainerTraffic") + else None + ) + self.max_runtime_in_seconds: Optional[str] = json_obj.get("MaxRuntimeInSeconds") + self.disable_output_compression: Optional[bool] = ( + bool(json_obj.get("DisableOutputCompression")) + if json_obj.get("DisableOutputCompression") + else None + ) + + def get_schema_version(self) -> str: + """Returns schema version.""" + return self.SCHEMA_VERSION + + def get_region(self) -> str: + """Returns hub region.""" + return self._region + + def _get_config_rankings( + self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE + ) -> Optional[Dict[str, JumpStartConfigRanking]]: + """Returns config rankings.""" + config_rankings = json_obj.get(f"{component_type.value}ConfigRankings") + return ( + { + alias: JumpStartConfigRanking(ranking, is_hub_content=True) + for alias, ranking in config_rankings.items() + } + if config_rankings + else None + ) + + def _get_config_components( + self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE + ) -> Optional[Dict[str, JumpStartConfigComponent]]: + """Returns config components.""" + config_components = json_obj.get(f"{component_type.value}ConfigComponents") + return ( + { + alias: JumpStartConfigComponent(alias, config, is_hub_content=True) + for alias, config in config_components.items() + } + if config_components + else None + ) + + def _get_configs( + self, json_obj: Dict[str, Any], component_type=_ComponentType.INFERENCE + ) -> Optional[JumpStartMetadataConfigs]: + """Returns configs.""" + if not (configs := json_obj.get(f"{component_type.value}Configs")): + return None + + configs_dict = {} + for alias, config in configs.items(): + config_components = None + if isinstance(config, dict) and (component_names := config.get("ComponentNames")): + config_components = { + name: getattr(self, f"{component_type.value.lower()}_config_components").get( + name + ) + for name in component_names + } + configs_dict[alias] = JumpStartMetadataConfig( + alias, config, json_obj, config_components, is_hub_content=True + ) + + if component_type == _ComponentType.INFERENCE: + 
config_rankings = self.inference_config_rankings
+            scope = JumpStartScriptScope.INFERENCE
+        else:
+            config_rankings = self.training_config_rankings
+            scope = JumpStartScriptScope.TRAINING
+
+        return JumpStartMetadataConfigs(configs_dict, config_rankings, scope)
+
+
+class HubNotebookDocument(HubDataHolderType):
+    """Data class for notebook type HubContentDocument from session.describe_hub_content()."""
+
+    SCHEMA_VERSION = "1.0.0"
+
+    __slots__ = ["notebook_location", "dependencies", "_region"]
+
+    _non_serializable_slots = ["_region"]
+
+    def __init__(self, json_obj: Dict[str, Any], region: str) -> None:
+        """Instantiates HubNotebookDocument object.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub content document.
+        """
+        self._region = region
+        self.from_json(json_obj)
+
+    def from_json(self, json_obj: Dict[str, Any]) -> None:
+        """Sets fields in object based on json.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub content description.
+        """
+        self.notebook_location = json_obj["NotebookLocation"]
+        self.dependencies: List[HubContentDependency] = [
+            HubContentDependency(dep) for dep in json_obj["Dependencies"]
+        ]
+
+    def get_schema_version(self) -> str:
+        """Returns schema version."""
+        return self.SCHEMA_VERSION
+
+    def get_region(self) -> str:
+        """Returns hub region."""
+        return self._region
+
+
+HubContentDocument = Union[HubModelDocument, HubNotebookDocument]
+
+
+class HubContentInfo(HubDataHolderType):
+    """Data class for the HubContentInfo from session.list_hub_contents()."""
+
+    __slots__ = [
+        "creation_time",
+        "document_schema_version",
+        "hub_content_arn",
+        "hub_content_name",
+        "hub_content_status",
+        "hub_content_type",
+        "hub_content_version",
+        "hub_content_description",
+        "hub_content_display_name",
+        "hub_content_search_keywords",
+        "_region",
+    ]
+
+    _non_serializable_slots = ["_region"]
+
+    def __init__(self, json_obj: Dict[str, Any]) -> None:
+        """Instantiates HubContentInfo object.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub content description.
+        """
+        self.from_json(json_obj)
+
+    def from_json(self, json_obj: Dict[str, Any]) -> None:
+        """Sets fields in object based on json.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub content description.
+        """
+        self.creation_time: str = json_obj["CreationTime"]
+        self.document_schema_version: str = json_obj["DocumentSchemaVersion"]
+        self.hub_content_arn: str = json_obj["HubContentArn"]
+        self.hub_content_name: str = json_obj["HubContentName"]
+        self.hub_content_status: str = json_obj["HubContentStatus"]
+        self.hub_content_type: HubContentType = HubContentType(json_obj["HubContentType"])
+        self.hub_content_version: str = json_obj["HubContentVersion"]
+        self.hub_content_description: Optional[str] = json_obj.get("HubContentDescription")
+        self.hub_content_display_name: Optional[str] = json_obj.get("HubContentDisplayName")
+        self._region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(
+            self.hub_content_arn
+        )
+        self.hub_content_search_keywords: Optional[List[str]] = json_obj.get(
+            "HubContentSearchKeywords"
+        )
+
+    def get_hub_region(self) -> Optional[str]:
+        """Returns the region the hub is in."""
+        return self._region
+
+
+class ListHubContentsResponse(HubDataHolderType):
+    """Data class for the Hub from session.list_hub_contents()"""
+
+    __slots__ = [
+        "hub_content_summaries",
+        "next_token",
+    ]
+
+    def __init__(self, json_obj: Dict[str, Any]) -> None:
+        """Instantiates ListHubContentsResponse object.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of session.list_hub_contents() response.
+        """
+        self.from_json(json_obj)
+
+    def from_json(self, json_obj: Dict[str, Any]) -> None:
+        """Sets fields in object based on json.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of session.list_hub_contents() response.
+        """
+        self.hub_content_summaries: List[HubContentInfo] = [
+            HubContentInfo(item) for item in json_obj["HubContentSummaries"]
+        ]
+        self.next_token: str = json_obj["NextToken"]
diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py
new file mode 100644
index 0000000000..0983122d09
--- /dev/null
+++ b/src/sagemaker/jumpstart/hub/parser_utils.py
@@ -0,0 +1,70 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+# pylint: skip-file
+"""This module contains utilities related to SageMaker JumpStart Hub."""
+from __future__ import absolute_import
+
+import re
+from typing import Any, Dict, List, Optional
+
+
+def camel_to_snake(camel_case_string: str) -> str:
+    """Converts camelCase to snake_case_string using a regex.
+
+    This regex cannot handle whitespace ("camelString TwoWords")
+    """
+    return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_string).lower()
+
+
+def snake_to_upper_camel(snake_case_string: str) -> str:
+    """Converts snake_case_string to UpperCamelCaseString."""
+    upper_camel_case_string = "".join(word.title() for word in snake_case_string.split("_"))
+    return upper_camel_case_string
+
+
+def walk_and_apply_json(
+    json_obj: Dict[Any, Any], apply, stop_keys: Optional[List[str]] = ["metrics"]
+) -> Dict[Any, Any]:
+    """Recursively walks a json object and applies a given function to the keys.
+
+    stop_keys (Optional[list[str]]): List of field keys that should stop the application function.
+        Any children of these keys will not have the application function applied to them.
+    """
+
+    def _walk_and_apply_json(json_obj, new):
+        if isinstance(json_obj, dict) and isinstance(new, dict):
+            for key, value in json_obj.items():
+                new_key = apply(key)
+                if (stop_keys and new_key not in stop_keys) or stop_keys is None:
+                    if isinstance(value, dict):
+                        new[new_key] = {}
+                        _walk_and_apply_json(value, new=new[new_key])
+                    elif isinstance(value, list):
+                        new[new_key] = []
+                        for item in value:
+                            _walk_and_apply_json(item, new=new[new_key])
+                    else:
+                        new[new_key] = value
+                else:
+                    new[new_key] = value
+        elif isinstance(json_obj, dict) and isinstance(new, list):
+            new.append(_walk_and_apply_json(json_obj, new={}))
+        elif isinstance(json_obj, list) and isinstance(new, dict):
+            new.update(json_obj)
+        elif isinstance(json_obj, list) and isinstance(new, list):
+            new.append(json_obj)
+        elif isinstance(json_obj, str) and isinstance(new, list):
+            new.append(json_obj)
+        return new
+
+    return _walk_and_apply_json(json_obj, new={})
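A brief, illustrative sanity check of these helpers (not part of the diff):

# Illustrative only: round-tripping key styles with the parser utilities.
from sagemaker.jumpstart.hub.parser_utils import (
    camel_to_snake,
    snake_to_upper_camel,
    walk_and_apply_json,
)

assert camel_to_snake("HubContentVersion") == "hub_content_version"
assert snake_to_upper_camel("hub_content_version") == "HubContentVersion"

# Keys are rewritten recursively, except below any key listed in stop_keys.
doc = {"HubContentName": "my-model", "Metrics": [{"Name": "val:loss"}]}
print(walk_and_apply_json(doc, camel_to_snake, stop_keys=["metrics"]))
# {'hub_content_name': 'my-model', 'metrics': [{'Name': 'val:loss'}]}

diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py
new file mode 100644
index 0000000000..8070b54e87
--- /dev/null
+++ b/src/sagemaker/jumpstart/hub/parsers.py
@@ -0,0 +1,288 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.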
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# pylint: skip-file +"""This module stores Hub converter utilities for JumpStart.""" +from __future__ import absolute_import + +from typing import Any, Dict, List +from sagemaker.jumpstart.enums import ModelSpecKwargType, NamingConventionType +from sagemaker.s3 import parse_s3_url +from sagemaker.jumpstart.types import ( + JumpStartModelSpecs, + HubContentType, + JumpStartDataHolderType, +) +from sagemaker.jumpstart.hub.interfaces import ( + DescribeHubContentResponse, + HubModelDocument, +) +from sagemaker.jumpstart.hub.parser_utils import ( + camel_to_snake, + snake_to_upper_camel, + walk_and_apply_json, +) + + +def _to_json(dictionary: Dict[Any, Any]) -> Dict[Any, Any]: + """Convert a nested dictionary of JumpStartDataHolderType into json with UpperCamelCase keys""" + for key, value in dictionary.items(): + if issubclass(type(value), JumpStartDataHolderType): + dictionary[key] = walk_and_apply_json(value.to_json(), snake_to_upper_camel) + elif isinstance(value, list): + new_value = [] + for value_in_list in value: + new_value_in_list = value_in_list + if issubclass(type(value_in_list), JumpStartDataHolderType): + new_value_in_list = walk_and_apply_json( + value_in_list.to_json(), snake_to_upper_camel + ) + new_value.append(new_value_in_list) + dictionary[key] = new_value + elif isinstance(value, dict): + for key_in_dict, value_in_dict in value.items(): + if issubclass(type(value_in_dict), JumpStartDataHolderType): + value[key_in_dict] = walk_and_apply_json( + value_in_dict.to_json(), snake_to_upper_camel + ) + return dictionary + + +def get_model_spec_arg_keys( + arg_type: ModelSpecKwargType, + naming_convention: NamingConventionType = NamingConventionType.DEFAULT, +) -> List[str]: + """Returns a list of arg keys for a specific model spec arg type. + + Args: + arg_type (ModelSpecKwargType): Type of the model spec's kwarg. + naming_convention (NamingConventionType): Type of naming convention to return. + + Raises: + ValueError: If the naming convention is not valid. 
+ """ + arg_keys: List[str] = [] + if arg_type == ModelSpecKwargType.DEPLOY: + arg_keys = [ + "ModelDataDownloadTimeout", + "ContainerStartupHealthCheckTimeout", + "InferenceAmiVersion", + ] + elif arg_type == ModelSpecKwargType.ESTIMATOR: + arg_keys = [ + "EncryptInterContainerTraffic", + "MaxRuntimeInSeconds", + "DisableOutputCompression", + "ModelDir", + ] + elif arg_type == ModelSpecKwargType.MODEL: + arg_keys = [] + elif arg_type == ModelSpecKwargType.FIT: + arg_keys = [] + + if naming_convention == NamingConventionType.SNAKE_CASE: + arg_keys = [camel_to_snake(key) for key in arg_keys] + elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE: + return arg_keys + else: + raise ValueError("Please provide a valid naming convention.") + return arg_keys + + +def get_model_spec_kwargs_from_hub_model_document( + arg_type: ModelSpecKwargType, + hub_content_document: Dict[str, Any], + naming_convention: NamingConventionType = NamingConventionType.UPPER_CAMEL_CASE, +) -> Dict[str, Any]: + """Returns a map of arg type to arg keys for a given hub content document. + + Args: + arg_type (ModelSpecKwargType): Type of the model spec's kwarg. + hub_content_document: A dictionary representation of hub content document. + naming_convention (NamingConventionType): Type of naming convention to return. + + """ + kwargs = dict() + keys = get_model_spec_arg_keys(arg_type, naming_convention=naming_convention) + for k in keys: + kwarg_value = hub_content_document.get(k) + if kwarg_value is not None: + kwargs[k] = kwarg_value + return kwargs + + +def make_model_specs_from_describe_hub_content_response( + response: DescribeHubContentResponse, +) -> JumpStartModelSpecs: + """Sets fields in JumpStartModelSpecs based on values in DescribeHubContentResponse + + Args: + response (Dict[str, any]): parsed DescribeHubContentResponse returned + from SageMaker:DescribeHubContent + """ + if response.hub_content_type not in {HubContentType.MODEL, HubContentType.MODEL_REFERENCE}: + raise AttributeError( + "Invalid content type, use either HubContentType.MODEL or HubContentType.MODEL_REFERENCE." 
+ ) + region = response.get_hub_region() + specs = {} + model_id = response.hub_content_name + specs["model_id"] = model_id + specs["version"] = response.hub_content_version + hub_model_document: HubModelDocument = response.hub_content_document + specs["url"] = hub_model_document.url + specs["min_sdk_version"] = hub_model_document.min_sdk_version + specs["model_types"] = hub_model_document.model_types + specs["capabilities"] = hub_model_document.capabilities + specs["training_supported"] = bool(hub_model_document.training_supported) + specs["incremental_training_supported"] = bool( + hub_model_document.incremental_training_supported + ) + specs["hosting_ecr_uri"] = hub_model_document.hosting_ecr_uri + specs["inference_configs"] = hub_model_document.inference_configs + specs["inference_config_components"] = hub_model_document.inference_config_components + specs["inference_config_rankings"] = hub_model_document.inference_config_rankings + + if hub_model_document.hosting_artifact_uri: + _, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_artifact_uri + ) + specs["hosting_artifact_key"] = hosting_artifact_key + specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri + + if hub_model_document.hosting_script_uri: + _, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_script_uri + ) + specs["hosting_script_key"] = hosting_script_key + + specs["inference_environment_variables"] = hub_model_document.inference_environment_variables + specs["inference_vulnerable"] = False + specs["inference_dependencies"] = hub_model_document.inference_dependencies + specs["inference_vulnerabilities"] = [] + specs["training_vulnerable"] = False + specs["training_vulnerabilities"] = [] + specs["deprecated"] = False + specs["deprecated_message"] = None + specs["deprecate_warn_message"] = None + specs["usage_info_message"] = None + specs["default_inference_instance_type"] = hub_model_document.default_inference_instance_type + specs["supported_inference_instance_types"] = ( + hub_model_document.supported_inference_instance_types + ) + specs["dynamic_container_deployment_supported"] = ( + hub_model_document.dynamic_container_deployment_supported + ) + specs["hosting_resource_requirements"] = hub_model_document.hosting_resource_requirements + + specs["hosting_prepacked_artifact_key"] = None + if hub_model_document.hosting_prepacked_artifact_uri is not None: + ( + hosting_prepacked_artifact_bucket, # pylint: disable=unused-variable + hosting_prepacked_artifact_key, + ) = parse_s3_url(hub_model_document.hosting_prepacked_artifact_uri) + specs["hosting_prepacked_artifact_key"] = hosting_prepacked_artifact_key + + hub_content_document_dict: Dict[str, Any] = hub_model_document.to_json() + + specs["fit_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.FIT, hub_content_document_dict + ) + specs["model_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.MODEL, hub_content_document_dict + ) + specs["deploy_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.DEPLOY, hub_content_document_dict + ) + specs["estimator_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.ESTIMATOR, hub_content_document_dict + ) + + specs["predictor_specs"] = hub_model_document.sage_maker_sdk_predictor_specifications + default_payloads: Dict[str, Any] = {} + if hub_model_document.default_payloads is not None: + for alias, payload in 
hub_model_document.default_payloads.items(): + default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake) + specs["default_payloads"] = default_payloads + specs["gated_bucket"] = hub_model_document.gated_bucket + specs["inference_volume_size"] = hub_model_document.inference_volume_size + specs["inference_enable_network_isolation"] = ( + hub_model_document.inference_enable_network_isolation + ) + specs["resource_name_base"] = hub_model_document.resource_name_base + + specs["hosting_eula_key"] = None + if hub_model_document.hosting_eula_uri is not None: + hosting_eula_bucket, hosting_eula_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_eula_uri + ) + specs["hosting_eula_key"] = hosting_eula_key + + if hub_model_document.hosting_model_package_arn: + specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn} + + specs["model_subscription_link"] = hub_model_document.model_subscription_link + + specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri + + specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants + + if specs["training_supported"]: + specs["training_ecr_uri"] = hub_model_document.training_ecr_uri + ( + training_artifact_bucket, # pylint: disable=unused-variable + training_artifact_key, + ) = parse_s3_url(hub_model_document.training_artifact_uri) + specs["training_artifact_key"] = training_artifact_key + ( + training_script_bucket, # pylint: disable=unused-variable + training_script_key, + ) = parse_s3_url(hub_model_document.training_script_uri) + specs["training_script_key"] = training_script_key + + specs["training_configs"] = hub_model_document.training_configs + specs["training_config_components"] = hub_model_document.training_config_components + specs["training_config_rankings"] = hub_model_document.training_config_rankings + + specs["training_dependencies"] = hub_model_document.training_dependencies + specs["default_training_instance_type"] = hub_model_document.default_training_instance_type + specs["supported_training_instance_types"] = ( + hub_model_document.supported_training_instance_types + ) + specs["metrics"] = hub_model_document.training_metrics + specs["training_prepacked_script_key"] = None + if hub_model_document.training_prepacked_script_uri is not None: + ( + training_prepacked_script_bucket, # pylint: disable=unused-variable + training_prepacked_script_key, + ) = parse_s3_url(hub_model_document.training_prepacked_script_uri) + specs["training_prepacked_script_key"] = training_prepacked_script_key + + specs["hyperparameters"] = hub_model_document.hyperparameters + specs["training_volume_size"] = hub_model_document.training_volume_size + specs["training_enable_network_isolation"] = ( + hub_model_document.training_enable_network_isolation + ) + if hub_model_document.training_model_package_artifact_uri: + specs["training_model_package_artifact_uris"] = { + region: hub_model_document.training_model_package_artifact_uri + } + specs["training_instance_type_variants"] = ( + hub_model_document.training_instance_type_variants + ) + if hub_model_document.default_training_dataset_uri: + _, default_training_dataset_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.default_training_dataset_uri + ) + specs["default_training_dataset_key"] = default_training_dataset_key + specs["default_training_dataset_uri"] = hub_model_document.default_training_dataset_uri + return JumpStartModelSpecs(_to_json(specs), 
is_hub_content=True)
diff --git a/src/sagemaker/jumpstart/hub/types.py b/src/sagemaker/jumpstart/hub/types.py
new file mode 100644
index 0000000000..1a68f84bbc
--- /dev/null
+++ b/src/sagemaker/jumpstart/hub/types.py
@@ -0,0 +1,35 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module stores types related to SageMaker JumpStart Hub."""
+from __future__ import absolute_import
+from typing import Dict
+from dataclasses import dataclass
+
+
+@dataclass
+class S3ObjectLocation:
+    """Helper class for S3 object references."""
+
+    bucket: str
+    key: str
+
+    def format_for_s3_copy(self) -> Dict[str, str]:
+        """Returns a dict formatted for S3 copy calls."""
+        return {
+            "Bucket": self.bucket,
+            "Key": self.key,
+        }
+
+    def get_uri(self) -> str:
+        """Returns the S3 URI."""
+        return f"s3://{self.bucket}/{self.key}"
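An illustrative sketch of the dataclass in use (the bucket and key are made-up values):

# Illustrative only: wrapping a bucket/key pair for copy calls and URI rendering.
from sagemaker.jumpstart.hub.types import S3ObjectLocation

loc = S3ObjectLocation(bucket="my-hub-bucket", key="curated-hub/model.tar.gz")
print(loc.get_uri())             # s3://my-hub-bucket/curated-hub/model.tar.gz
print(loc.format_for_s3_copy())  # {'Bucket': 'my-hub-bucket', 'Key': 'curated-hub/model.tar.gz'}

diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py
new file mode 100644
index 0000000000..0df5e9d5c3
--- /dev/null
+++ b/src/sagemaker/jumpstart/hub/utils.py
@@ -0,0 +1,260 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.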
+# pylint: skip-file
+"""This module contains utilities related to SageMaker JumpStart Hub."""
+from __future__ import absolute_import
+import re
+from typing import Optional, List, Any
+from sagemaker.session import Session
+from sagemaker.utils import aws_partition
+from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo
+from sagemaker.jumpstart import constants
+from packaging.specifiers import SpecifierSet, InvalidSpecifier
+from packaging import version
+
+PROPRIETARY_VERSION_KEYWORD = "@marketplace-version:"
+
+
+def _convert_str_to_optional(string: str) -> Optional[str]:
+    if string == "None":
+        string = None
+    return string
+
+
+def get_info_from_hub_resource_arn(
+    arn: str,
+) -> HubArnExtractedInfo:
+    """Extracts descriptive information from a Hub or HubContent Arn."""
+
+    match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn)
+    if match:
+        partition = match.group(1)
+        hub_region = match.group(2)
+        account_id = match.group(3)
+        hub_name = match.group(4)
+        hub_content_type = match.group(5)
+        hub_content_name = match.group(6)
+        hub_content_version = _convert_str_to_optional(match.group(7))
+
+        return HubArnExtractedInfo(
+            partition=partition,
+            region=hub_region,
+            account_id=account_id,
+            hub_name=hub_name,
+            hub_content_type=hub_content_type,
+            hub_content_name=hub_content_name,
+            hub_content_version=hub_content_version,
+        )
+
+    match = re.match(constants.HUB_ARN_REGEX, arn)
+    if match:
+        partition = match.group(1)
+        hub_region = match.group(2)
+        account_id = match.group(3)
+        hub_name = match.group(4)
+        return HubArnExtractedInfo(
+            partition=partition,
+            region=hub_region,
+            account_id=account_id,
+            hub_name=hub_name,
+        )
+
+
+def construct_hub_arn_from_name(
+    hub_name: str,
+    region: Optional[str] = None,
+    session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+    account_id: Optional[str] = None,
+) -> str:
+    """Constructs a Hub arn from the Hub name using default Session values."""
+    if session is None:
+        # session is overridden to none by some callers
+        session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION
+
+    account_id = account_id or session.account_id()
+    region = region or session.boto_region_name
+    partition = aws_partition(region)
+
+    return f"arn:{partition}:sagemaker:{region}:{account_id}:hub/{hub_name}"
+
+
+def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str:
+    """Constructs a HubContent model arn from the Hub name, model name, and model version."""
+
+    info = get_info_from_hub_resource_arn(hub_arn)
+    arn = (
+        f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/"
+        f"{info.hub_name}/{HubContentType.MODEL.value}/{model_name}/{version}"
+    )
+
+    return arn
+
+
+def construct_hub_model_reference_arn_from_inputs(
+    hub_arn: str, model_name: str, version: str
+) -> str:
+    """Constructs a HubContent model reference arn from the Hub name, model name, and version."""
+
+    info = get_info_from_hub_resource_arn(hub_arn)
+    arn = (
+        f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/"
+        f"{info.hub_name}/{HubContentType.MODEL_REFERENCE.value}/{model_name}/{version}"
+    )
+
+    return arn
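A short, illustrative round trip through the ARN helpers (the region and account id are made-up values, not from this diff):

# Illustrative only: building a hub ARN and extracting its parts again.
from sagemaker.jumpstart.hub.utils import (
    construct_hub_arn_from_name,
    get_info_from_hub_resource_arn,
)

# account_id and region are passed explicitly here, so no Session lookup happens.
arn = construct_hub_arn_from_name("my-curated-hub", region="us-west-2", account_id="123456789012")
# arn:aws:sagemaker:us-west-2:123456789012:hub/my-curated-hub

info = get_info_from_hub_resource_arn(arn)
print(info.hub_name, info.region)  # my-curated-hub us-west-2

+
+
+def generate_hub_arn_for_init_kwargs(
+    hub_name: str, region: Optional[str] = None, session: Optional[Session] = None
+):
+    """Generates the Hub Arn for JumpStart class args from a HubName or Arn.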
+ + Args: + hub_name (str): HubName or HubArn from JumpStart class args + region (str): Region from JumpStart class args + session (Session): Custom SageMaker Session from JumpStart class args + """ + + hub_arn = None + if hub_name: + if hub_name == constants.JUMPSTART_MODEL_HUB_NAME: + return None + match = re.match(constants.HUB_ARN_REGEX, hub_name) + if match: + hub_arn = hub_name + else: + hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session) + return hub_arn + + +def is_gated_bucket(bucket_name: str) -> bool: + """Returns true if the bucket name is the JumpStart gated bucket.""" + return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET + + +def get_hub_model_version( + hub_name: str, + hub_model_name: str, + hub_model_type: str, + hub_model_version: Optional[str] = None, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Returns available Jumpstart hub model version. + + It will attempt both a semantic HubContent version search and Marketplace version search. + If the Marketplace version is also semantic, this function will default to HubContent version. + + Raises: + ClientError: If the specified model is not found in the hub. + KeyError: If the specified model version is not found. + """ + if sagemaker_session is None: + # sagemaker_session is overridden to none by some callers + sagemaker_session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION + + try: + hub_content_summaries = _list_hub_content_versions_helper( + hub_name=hub_name, + hub_content_name=hub_model_name, + hub_content_type=hub_model_type, + sagemaker_session=sagemaker_session, + ) + except Exception as ex: + raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") + + try: + return _get_hub_model_version_for_open_weight_version( + hub_content_summaries, hub_model_version + ) + except KeyError: + marketplace_hub_content_version = _get_hub_model_version_for_marketplace_version( + hub_content_summaries, hub_model_version + ) + if marketplace_hub_content_version: + return marketplace_hub_content_version + raise + + +def _list_hub_content_versions_helper( + hub_name, hub_content_name, hub_content_type, sagemaker_session +): + all_hub_content_summaries = [] + list_hub_content_versions_response = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, hub_content_name=hub_content_name, hub_content_type=hub_content_type + ) + all_hub_content_summaries.extend(list_hub_content_versions_response.get("HubContentSummaries")) + while "NextToken" in list_hub_content_versions_response: + list_hub_content_versions_response = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, + hub_content_name=hub_content_name, + hub_content_type=hub_content_type, + next_token=list_hub_content_versions_response["NextToken"], + ) + all_hub_content_summaries.extend( + list_hub_content_versions_response.get("HubContentSummaries") + ) + return all_hub_content_summaries + + +def _get_hub_model_version_for_open_weight_version( + hub_content_summaries: List[Any], hub_model_version: Optional[str] = None +) -> str: + available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries] + + if hub_model_version == "*" or hub_model_version is None: + return str(max(version.parse(v) for v in available_model_versions)) + + try: + spec = SpecifierSet(f"=={hub_model_version}") + except InvalidSpecifier: + raise KeyError(f"Bad semantic version: {hub_model_version}") + available_versions_filtered = 
list(spec.filter(available_model_versions)) + if not available_versions_filtered: + raise KeyError("Model version not available in the Hub") + hub_model_version = str(max(available_versions_filtered)) + + return hub_model_version + + +def _get_hub_model_version_for_marketplace_version( + hub_content_summaries: List[Any], marketplace_version: str +) -> Optional[str]: + """Returns the HubContent version associated with the Marketplace version. + + This function will check within the HubContentSearchKeywords for the proprietary version. + """ + for model in hub_content_summaries: + model_search_keywords = model.get("HubContentSearchKeywords", []) + if _hub_search_keywords_contains_marketplace_version( + model_search_keywords, marketplace_version + ): + return model.get("HubContentVersion") + + return None + + +def _hub_search_keywords_contains_marketplace_version( + model_search_keywords: List[str], marketplace_version: str +) -> bool: + proprietary_version_keyword = next( + filter(lambda s: s.startswith(PROPRIETARY_VERSION_KEYWORD), model_search_keywords), None + ) + + if not proprietary_version_keyword: + return False + + proprietary_version = proprietary_version_keyword.lstrip(PROPRIETARY_VERSION_KEYWORD) + if proprietary_version == marketplace_version: + return True + + return False diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 4529bc11b9..7dec3d78f9 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -14,9 +14,11 @@ from __future__ import absolute_import -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Any, Union +import pandas as pd from botocore.exceptions import ClientError +from sagemaker_core.shapes import ModelAccessConfig from sagemaker import payloads from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -24,6 +26,7 @@ from sagemaker.enums import EndpointType from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.jumpstart.accessors import JumpStartModelsAccessor +from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import ( INVALID_MODEL_ID_ERROR_MSG, @@ -36,13 +39,27 @@ get_init_kwargs, get_register_kwargs, ) -from sagemaker.jumpstart.types import JumpStartSerializablePayload +from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint +from sagemaker.jumpstart.types import ( + JumpStartSerializablePayload, + DeploymentConfigMetadata, +) from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, + get_jumpstart_configs, + get_metrics_from_deployment_configs, + add_instance_rate_stats_to_benchmark_metrics, + deployment_config_response_data, + _deployment_config_lru_cache, + _add_model_access_configs_to_model_data_sources, ) -from sagemaker.jumpstart.constants import JUMPSTART_LOGGER +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, @@ -55,6 +72,7 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.model_metrics import ModelMetrics from 
sagemaker.metadata_properties import MetadataProperties +from sagemaker.model_life_cycle import ModelLifeCycle from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements @@ -69,6 +87,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -76,7 +95,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, @@ -92,6 +111,8 @@ def __init__( git_config: Optional[Dict[str, str]] = None, model_package_arn: Optional[str] = None, resources: Optional[ResourceRequirements] = None, + config_name: Optional[str] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, ): """Initializes a ``JumpStartModel``. @@ -106,6 +127,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with @@ -127,7 +149,7 @@ def __init__( It can be null if this is being used to create a Model to pass to a ``PipelineModel`` which has its own Role field. (Default: None). - predictor_cls (Optional[callable[string, sagemaker.session.Session]]): A + predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A function to call to create a predictor (Default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. (Default: None). @@ -156,8 +178,8 @@ def __init__( source_dir (Optional[str]): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (Default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory is + preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. If the directory points to S3, no code is uploaded and the S3 location is used instead. (Default: None). @@ -277,10 +299,20 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). + config_name (Optional[str]): The name of the JumpStart config that can be + optionally applied to the model. + additional_model_data_sources (Optional[Dict[str, Any]]): Additional location + of SageMaker model data (default: None). Raises: ValueError: If the model ID is not recognized by JumpStart. 
""" + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + def _validate_model_id_and_type(): return validate_model_id_and_get_type( model_id=model_id, @@ -288,13 +320,14 @@ def _validate_model_id_and_type(): region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, + hub_arn=hub_arn, ) self.model_type = _validate_model_id_and_type() if not self.model_type: JumpStartModelsAccessor.reset_cache() self.model_type = _validate_model_id_and_type() - if not self.model_type: + if not self.model_type and not hub_arn: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) self._model_data_is_set = model_data is not None @@ -303,6 +336,7 @@ def _validate_model_id_and_type(): model_from_estimator=False, model_type=self.model_type, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -326,25 +360,44 @@ def _validate_model_id_and_type(): git_config=git_config, model_package_arn=model_package_arn, resources=resources, + config_name=config_name, + additional_model_data_sources=additional_model_data_sources, ) self.orig_predictor_cls = predictor_cls self.model_id = model_init_kwargs.model_id self.model_version = model_init_kwargs.model_version + self.hub_arn = model_init_kwargs.hub_arn self.instance_type = model_init_kwargs.instance_type self.resources = model_init_kwargs.resources self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + self.role = role + self.config_name = model_init_kwargs.config_name + self.additional_model_data_sources = model_init_kwargs.additional_model_data_sources + self.model_reference_arn = model_init_kwargs.model_reference_arn if self.model_type == JumpStartModelType.PROPRIETARY: self.log_subscription_warning() - super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) + model_init_kwargs_dict = model_init_kwargs.to_kwargs_dict() + + super(JumpStartModel, self).__init__(**model_init_kwargs_dict) self.model_package_arn = model_init_kwargs.model_package_arn + self.init_kwargs = model_init_kwargs.to_kwargs_dict(False) + + self._metadata_configs = get_jumpstart_configs( + region=self.region, + model_id=self.model_id, + model_version=self.model_version, + sagemaker_session=self.sagemaker_session, + model_type=self.model_type, + hub_arn=self.hub_arn, + ) def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" @@ -352,6 +405,7 @@ def log_subscription_warning(self) -> None: region=self.region, model_id=self.model_id, version=self.model_version, + hub_arn=self.hub_arn, model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, @@ -373,6 +427,7 @@ def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: return payloads.retrieve_all_examples( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -395,6 +450,7 @@ def retrieve_example_payload(self) -> 
JumpStartSerializablePayload: return payloads.retrieve_example( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, model_type=self.model_type, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, @@ -402,6 +458,117 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload: sagemaker_session=self.sagemaker_session, ) + def set_deployment_config(self, config_name: str, instance_type: str) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. + """ + self.__init__( + model_id=self.model_id, + model_version=self.model_version, + instance_type=instance_type, + config_name=config_name, + sagemaker_session=self.sagemaker_session, + role=self.role, + ) + + @property + def deployment_config(self) -> Optional[Dict[str, Any]]: + """The deployment config that will be applied to ``This`` model. + + Returns: + Optional[Dict[str, Any]]: Deployment config. + """ + if self.config_name is None: + return None + for config in self.list_deployment_configs(): + if config.get("DeploymentConfigName") == self.config_name: + return config + return None + + @property + def benchmark_metrics(self) -> pd.DataFrame: + """Benchmark Metrics for deployment configs. + + Returns: + Benchmark Metrics: Pandas DataFrame object. + """ + df = pd.DataFrame(self._get_deployment_configs_benchmarks_data()) + blank_index = [""] * len(df) + df.index = blank_index + return df + + def display_benchmark_metrics(self, **kwargs) -> None: + """Display deployment configs benchmark metrics.""" + df = self.benchmark_metrics + + instance_type = kwargs.get("instance_type") + if instance_type: + df = df[df["Instance Type"].str.contains(instance_type)] + + print(df.to_markdown(index=False, floatfmt=".2f")) + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for ``This`` model. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. + """ + return deployment_config_response_data( + self._get_deployment_configs(self.config_name, self.instance_type) + ) + + @classmethod + def attach( + cls, + endpoint_name: str, + inference_component_name: Optional[str] = None, + model_id: Optional[str] = None, + model_version: Optional[str] = None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_name: Optional[str] = None, + ) -> "JumpStartModel": + """Attaches a JumpStartModel object to an existing SageMaker Endpoint. + + The model id, version (and inference component name) can be inferred from the tags. 
+ """ + + inferred_model_id = inferred_model_version = inferred_inference_component_name = None + + if inference_component_name is None or model_id is None or model_version is None: + ( + inferred_model_id, + inferred_model_version, + inferred_inference_component_name, + _, + _, + ) = get_model_info_from_endpoint( + endpoint_name=endpoint_name, + inference_component_name=inference_component_name, + sagemaker_session=sagemaker_session, + ) + + model_id = model_id or inferred_model_id + model_version = model_version or inferred_model_version or "*" + inference_component_name = inference_component_name or inferred_inference_component_name + + model = JumpStartModel( + model_id=model_id, + model_version=model_version, + sagemaker_session=sagemaker_session, + hub_name=hub_name, + ) + model.endpoint_name = endpoint_name + model.inference_component_name = inference_component_name + + return model + def _create_sagemaker_model( self, instance_type=None, @@ -480,6 +647,7 @@ def deploy( deserializer: Optional[BaseDeserializer] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, + inference_component_name: Optional[str] = None, tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = True, @@ -496,6 +664,9 @@ def deploy( resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, endpoint_type: EndpointType = EndpointType.MODEL_BASED, + routing_config: Optional[Dict[str, Any]] = None, + model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None, + inference_ami_version: Optional[str] = None, ) -> PredictorBase: """Creates endpoint by calling base ``Model`` class `deploy` method. @@ -590,6 +761,13 @@ def deploy( endpoint. endpoint_type (EndpointType): The type of endpoint used to deploy models. (Default: EndpointType.MODEL_BASED). + routing_config (Optional[Dict]): Settings the control how the endpoint routes + incoming traffic to the instances that the endpoint hosts. + model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that require + ModelAccessConfig, provide a `{ "model_id", ModelAccessConfig(accept_eula=True) }` + to indicate whether model terms of use have been accepted. The `accept_eula` value + must be explicitly defined as `True` in order to accept the end-user license + agreement (EULA) that some models require. (Default: None) Raises: MarketplaceModelSubscriptionError: If the caller is not subscribed to the model. 
@@ -599,6 +777,7 @@ def deploy( model_id=self.model_id, model_version=self.model_version, region=self.region, + hub_arn=self.hub_arn, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, initial_instance_count=initial_instance_count, @@ -607,6 +786,7 @@ def deploy( deserializer=deserializer, accelerator_type=accelerator_type, endpoint_name=endpoint_name, + inference_component_name=inference_component_name, tags=format_tags(tags), kms_key=kms_key, wait=wait, @@ -620,11 +800,16 @@ def deploy( explainer_config=explainer_config, sagemaker_session=self.sagemaker_session, accept_eula=accept_eula, + model_reference_arn=self.model_reference_arn, endpoint_logging=endpoint_logging, resources=resources, managed_instance_scaling=managed_instance_scaling, endpoint_type=endpoint_type, model_type=self.model_type, + config_name=self.config_name, + routing_config=routing_config, + model_access_configs=model_access_configs, + inference_ami_version=inference_ami_version, ) if ( self.model_type == JumpStartModelType.PROPRIETARY @@ -634,6 +819,27 @@ def deploy( f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models." ) + # No resources given to deploy() but present 'resources' key in deploy_kwargs means default + # JumpStart resource requirements are being used + if hasattr(self, "_is_sharded_model") and not resources and deploy_kwargs.resources: + if ( + self._is_sharded_model + and deploy_kwargs.resources.num_cpus + and deploy_kwargs.resources.num_cpus > 0 + ): + JUMPSTART_LOGGER.warning( + "NumOfCpuCoresRequired should be 0 for the best experience with SageMaker Fast " + "Model Loading. Overriding the requested `num_cpus` to 0." + ) + deploy_kwargs.resources.num_cpus = 0 + + self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources( + self.additional_model_data_sources, + deploy_kwargs.model_access_configs, + deploy_kwargs.model_id, + deploy_kwargs.region, + ) + try: predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict()) except ClientError as e: @@ -644,6 +850,8 @@ def deploy( model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, + config_name=self.config_name, + hub_arn=self.hub_arn, ).model_subscription_link get_proprietary_model_subscription_error(e, subscription_link) raise @@ -654,11 +862,13 @@ def deploy( predictor=predictor, model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, sagemaker_session=self.sagemaker_session, model_type=self.model_type, + config_name=self.config_name, ) # If a predictor class was passed, do not mutate predictor @@ -688,6 +898,9 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + accept_eula: Optional[bool] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -735,14 +948,28 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). 
- + model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative and + quantitative information about a model (default: None). + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). + model_life_cycle (ModelLifeCycle): The model life cycle for the model package (default: None). Returns: A `sagemaker.model.ModelPackage` instance. """ + if model_package_group_name is None: + model_package_group_name = self.model_id + if self.model_type is JumpStartModelType.PROPRIETARY: + source_uri = self.model_package_arn + register_kwargs = get_register_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, + model_type=self.model_type, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -769,6 +996,10 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + config_name=self.config_name, + model_card=model_card, + accept_eula=accept_eula, + model_life_cycle=model_life_cycle, ) model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) @@ -786,6 +1017,92 @@ def register_deploy_wrapper(*args, **kwargs): return model_package + @_deployment_config_lru_cache + def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]: + """Deployment configs benchmark metrics. + + Returns: + Dict[str, Any]: Deployment config benchmark data. + """ + return get_metrics_from_deployment_configs( + self._get_deployment_configs(None, None), + ) + + @_deployment_config_lru_cache + def _get_deployment_configs( + self, selected_config_name: Optional[str], selected_instance_type: Optional[str] + ) -> List[DeploymentConfigMetadata]: + """Retrieve deployment configs metadata. + + Args: + selected_config_name (Optional[str]): The name of the selected deployment config. + selected_instance_type (Optional[str]): The selected instance type.
+ """ + deployment_configs = [] + if not self._metadata_configs: + return deployment_configs + + err = None + for config_name, metadata_config in self._metadata_configs.items(): + if selected_config_name == config_name: + instance_type_to_use = selected_instance_type + else: + instance_type_to_use = metadata_config.resolved_config.get( + "default_inference_instance_type" + ) + + if metadata_config.benchmark_metrics: + ( + err, + metadata_config.benchmark_metrics, + ) = add_instance_rate_stats_to_benchmark_metrics( + self.region, metadata_config.benchmark_metrics + ) + + config_components = metadata_config.config_components.get(config_name) + image_uri = ( + ( + config_components.hosting_instance_type_variants.get("regional_aliases", {}) + .get(self.region, {}) + .get("alias_ecr_uri_1") + ) + if config_components + else self.image_uri + ) + + init_kwargs = get_init_kwargs( + config_name=config_name, + model_id=self.model_id, + instance_type=instance_type_to_use, + sagemaker_session=self.sagemaker_session, + image_uri=image_uri, + region=self.region, + model_version=self.model_version, + hub_arn=self.hub_arn, + ) + deploy_kwargs = get_deploy_kwargs( + model_id=self.model_id, + instance_type=instance_type_to_use, + sagemaker_session=self.sagemaker_session, + region=self.region, + model_version=self.model_version, + hub_arn=self.hub_arn, + ) + + deployment_config_metadata = DeploymentConfigMetadata( + config_name, + metadata_config, + init_kwargs, + deploy_kwargs, + ) + deployment_configs.append(deployment_config_metadata) + + if err and err["Code"] == "AccessDeniedException": + error_message = "Instance rate metrics will be omitted. Reason: %s" + JUMPSTART_LOGGER.warning(error_message, err["Message"]) + + return deployment_configs + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 732493ce3b..781548b42a 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -329,9 +329,12 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin return sorted(list(model_id_version_dict.keys())) if not list_old_models: - model_id_version_dict = { - model_id: set([max(versions)]) for model_id, versions in model_id_version_dict.items() - } + for model_id, versions in model_id_version_dict.items(): + try: + model_id_version_dict.update({model_id: set([max(versions)])}) + except TypeError: + versions = [str(v) for v in versions] + model_id_version_dict.update({model_id: set([max(versions)])}) model_id_version_set: Set[Tuple[str, str]] = set() for model_id in model_id_version_dict: @@ -532,6 +535,7 @@ def get_model_url( model_version: str, region: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> str: """Retrieve web url describing pretrained model. 
@@ -560,5 +564,6 @@ def get_model_url( sagemaker_session=sagemaker_session, scope=JumpStartScriptScope.INFERENCE, model_type=model_type, + config_name=config_name, ) return model_specs.url diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 595f801598..9c6716dc64 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -23,7 +23,7 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -from sagemaker.jumpstart.enums import MIMEType +from sagemaker.jumpstart.enums import JumpStartModelType, MIMEType from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, @@ -61,6 +61,8 @@ def _construct_payload( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + alias: Optional[str] = None, ) -> Optional[JumpStartSerializablePayload]: """Returns example payload from prompt. @@ -83,6 +85,8 @@ def _construct_payload( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + model_type (JumpStartModelType): The type of the model, can be open weights model or + proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if this feature is unavailable for the specified model. @@ -94,11 +98,14 @@ def _construct_payload( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if payloads is None or len(payloads) == 0: return None - payload_to_use: JumpStartSerializablePayload = list(payloads.values())[0] + payload_to_use: JumpStartSerializablePayload = ( + payloads[alias] if alias else list(payloads.values())[0] + ) prompt_key: Optional[str] = payload_to_use.prompt_key if prompt_key is None: diff --git a/src/sagemaker/jumpstart/region_config.json b/src/sagemaker/jumpstart/region_config.json new file mode 100644 index 0000000000..136bf8256c --- /dev/null +++ b/src/sagemaker/jumpstart/region_config.json @@ -0,0 +1,167 @@ +{ + "af-south-1": { + "content_bucket": "jumpstart-cache-prod-af-south-1", + "gated_content_bucket": "jumpstart-private-cache-prod-af-south-1" + }, + "ap-east-1": { + "content_bucket": "jumpstart-cache-prod-ap-east-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-east-1" + }, + "ap-east-2": { + "content_bucket": "jumpstart-cache-prod-ap-east-2", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-east-2" + }, + "ap-northeast-1": { + "content_bucket": "jumpstart-cache-prod-ap-northeast-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-1", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-1" + }, + "ap-northeast-2": { + "content_bucket": "jumpstart-cache-prod-ap-northeast-2", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-2", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-2" + }, + "ap-northeast-3": { + "content_bucket": "jumpstart-cache-prod-ap-northeast-3", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-northeast-3", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-northeast-3" + 
}, + "ap-south-1": { + "content_bucket": "jumpstart-cache-prod-ap-south-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-south-1", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-south-1" + }, + "ap-south-2": { + "content_bucket": "jumpstart-cache-prod-ap-south-2", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-south-2" + }, + "ap-southeast-1": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-1", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-southeast-1" + }, + "ap-southeast-2": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-2", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-2", + "neo_content_bucket": "sagemaker-sd-models-prod-ap-southeast-2" + }, + "ap-southeast-3": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-3", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-3" + }, + "ap-southeast-4": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-4", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-4" + }, + "ap-southeast-5": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-5", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-5" + }, + "ap-southeast-7": { + "content_bucket": "jumpstart-cache-prod-ap-southeast-7", + "gated_content_bucket": "jumpstart-private-cache-prod-ap-southeast-7" + }, + "ca-central-1": { + "content_bucket": "jumpstart-cache-prod-ca-central-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ca-central-1", + "neo_content_bucket": "sagemaker-sd-models-prod-ca-central-1" + }, + "ca-west-1": { + "content_bucket": "jumpstart-cache-prod-ca-west-1", + "gated_content_bucket": "jumpstart-private-cache-prod-ca-west-1" + }, + "cn-north-1": { + "content_bucket": "jumpstart-cache-prod-cn-north-1", + "gated_content_bucket": "jumpstart-private-cache-prod-cn-north-1" + }, + "cn-northwest-1": { + "content_bucket": "jumpstart-cache-prod-cn-northwest-1", + "gated_content_bucket": "jumpstart-private-cache-prod-cn-northwest-1" + }, + "eu-central-1": { + "content_bucket": "jumpstart-cache-prod-eu-central-1", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-central-1", + "neo_content_bucket": "sagemaker-sd-models-prod-eu-central-1" + }, + "eu-central-2": { + "content_bucket": "jumpstart-cache-prod-eu-central-2", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-central-2" + }, + "eu-north-1": { + "content_bucket": "jumpstart-cache-prod-eu-north-1", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-north-1", + "neo_content_bucket": "sagemaker-sd-models-prod-eu-north-1" + }, + "eu-south-1": { + "content_bucket": "jumpstart-cache-prod-eu-south-1", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-south-1" + }, + "eu-south-2": { + "content_bucket": "jumpstart-cache-prod-eu-south-2", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-south-2" + }, + "eu-west-1": { + "content_bucket": "jumpstart-cache-prod-eu-west-1", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-west-1", + "neo_content_bucket": "sagemaker-sd-models-prod-eu-west-1" + }, + "eu-west-2": { + "content_bucket": "jumpstart-cache-prod-eu-west-2", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-west-2", + "neo_content_bucket": "sagemaker-sd-models-prod-eu-west-2" + }, + "eu-west-3": { + "content_bucket": "jumpstart-cache-prod-eu-west-3", + "gated_content_bucket": "jumpstart-private-cache-prod-eu-west-3", + "neo_content_bucket": 
"sagemaker-sd-models-prod-eu-west-3" + }, + "il-central-1": { + "content_bucket": "jumpstart-cache-prod-il-central-1", + "gated_content_bucket": "jumpstart-private-cache-prod-il-central-1" + }, + "me-central-1": { + "content_bucket": "jumpstart-cache-prod-me-central-1", + "gated_content_bucket": "jumpstart-private-cache-prod-me-central-1" + }, + "me-south-1": { + "content_bucket": "jumpstart-cache-prod-me-south-1", + "gated_content_bucket": "jumpstart-private-cache-prod-me-south-1" + }, + "mx-central-1": { + "content_bucket": "jumpstart-cache-prod-mx-central-1", + "gated_content_bucket": "jumpstart-private-cache-prod-mx-central-1" + }, + "sa-east-1": { + "content_bucket": "jumpstart-cache-prod-sa-east-1", + "gated_content_bucket": "jumpstart-private-cache-prod-sa-east-1", + "neo_content_bucket": "sagemaker-sd-models-prod-sa-east-1" + }, + "us-east-1": { + "content_bucket": "jumpstart-cache-prod-us-east-1", + "gated_content_bucket": "jumpstart-private-cache-prod-us-east-1", + "neo_content_bucket": "sagemaker-sd-models-prod-us-east-1" + }, + "us-east-2": { + "content_bucket": "jumpstart-cache-prod-us-east-2", + "gated_content_bucket": "jumpstart-private-cache-prod-us-east-2", + "neo_content_bucket": "sagemaker-sd-models-prod-us-east-2" + }, + "us-gov-east-1": { + "content_bucket": "jumpstart-cache-prod-us-gov-east-1", + "gated_content_bucket": "jumpstart-private-cache-prod-us-gov-east-1" + }, + "us-gov-west-1": { + "content_bucket": "jumpstart-cache-prod-us-gov-west-1", + "gated_content_bucket": "jumpstart-private-cache-prod-us-gov-west-1" + }, + "us-west-1": { + "content_bucket": "jumpstart-cache-prod-us-west-1", + "gated_content_bucket": "jumpstart-private-cache-prod-us-west-1", + "neo_content_bucket": "sagemaker-sd-models-prod-us-west-1" + }, + "us-west-2": { + "content_bucket": "jumpstart-cache-prod-us-west-2", + "gated_content_bucket": "jumpstart-private-cache-prod-us-west-2", + "neo_content_bucket": "sagemaker-sd-models-prod-us-west-2" + } +} \ No newline at end of file diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index e511a052d1..0955ae9480 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -17,17 +17,17 @@ from typing import Optional, Tuple from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn +from sagemaker.jumpstart.utils import get_jumpstart_model_info_from_resource_arn from sagemaker.session import Session from sagemaker.utils import aws_partition -def get_model_id_version_from_endpoint( +def get_model_info_from_endpoint( endpoint_name: str, inference_component_name: Optional[str] = None, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str, Optional[str]]: - """Given an endpoint and optionally inference component names, return the model ID and version. +) -> Tuple[str, str, Optional[str], Optional[str], Optional[str]]: + """Optionally inference component names, return the model ID, version and config name. Infers the model ID and version based on the resource tags. Returns a tuple of the model ID and version. 
A third string element is included in the tuple for any inferred inference @@ -46,7 +46,9 @@ def get_model_id_version_from_endpoint( ( model_id, model_version, - ) = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 + inference_config_name, + training_config_name, + ) = _get_model_info_from_inference_component_endpoint_with_inference_component_name( # noqa E501 # pylint: disable=c0301 inference_component_name, sagemaker_session ) @@ -54,22 +56,35 @@ ( model_id, model_version, + inference_config_name, + training_config_name, inference_component_name, - ) = _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 + ) = _get_model_info_from_inference_component_endpoint_without_inference_component_name( # noqa E501 # pylint: disable=c0301 endpoint_name, sagemaker_session ) else: - model_id, model_version = _get_model_id_version_from_model_based_endpoint( + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = _get_model_info_from_model_based_endpoint( endpoint_name, inference_component_name, sagemaker_session ) - return model_id, model_version, inference_component_name + return ( + model_id, + model_version, + inference_component_name, + inference_config_name, + training_config_name, + ) -def _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( +def _get_model_info_from_inference_component_endpoint_without_inference_component_name( endpoint_name: str, sagemaker_session: Session -) -> Tuple[str, str, str]: - """Given an endpoint name, derives the model ID, version, and inferred inference component name. +) -> Tuple[str, str, str, str, str]: + """Derives the model ID, version, config names, and inferred inference component name. This function assumes the endpoint corresponds to an inference-component-based endpoint. An endpoint is inference-component-based if and only if the associated endpoint config @@ -98,14 +113,14 @@ ) inference_component_name = inference_component_names[0] return ( - *_get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + *_get_model_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name, sagemaker_session ), inference_component_name, ) -def _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( +def _get_model_info_from_inference_component_endpoint_with_inference_component_name( inference_component_name: str, sagemaker_session: Session ): """Returns the model ID and version inferred from a SageMaker inference component. @@ -123,9 +138,12 @@ f"inference-component/{inference_component_name}" ) - model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( - inference_component_arn, sagemaker_session - ) + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(inference_component_arn, sagemaker_session) if not model_id: raise ValueError( @@ -134,15 +152,15 @@ def _get_model_id_version_from_inference_compo "when retrieving default predictor for this inference component."
) - return model_id, model_version + return model_id, model_version, inference_config_name, training_config_name -def _get_model_id_version_from_model_based_endpoint( +def _get_model_info_from_model_based_endpoint( endpoint_name: str, inference_component_name: Optional[str], sagemaker_session: Session, -) -> Tuple[str, str]: - """Returns the model ID and version inferred from a model-based endpoint. +) -> Tuple[str, str, Optional[str], Optional[str]]: + """Returns the model ID, version and config name inferred from a model-based endpoint. Raises: ValueError: If an inference component name is supplied, or if the endpoint does @@ -161,9 +179,12 @@ def _get_model_id_version_from_model_based_endpoint( endpoint_arn = f"arn:{partition}:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}" - model_id, model_version = get_jumpstart_model_id_version_from_resource_arn( - endpoint_arn, sagemaker_session - ) + ( + model_id, + model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(endpoint_arn, sagemaker_session) if not model_id: raise ValueError( @@ -172,14 +193,14 @@ def _get_model_id_version_from_model_based_endpoint( "predictor for this endpoint." ) - return model_id, model_version + return model_id, model_version, inference_config_name, training_config_name -def get_model_id_version_from_training_job( +def get_model_info_from_training_job( training_job_name: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Tuple[str, str]: - """Returns the model ID and version inferred from a training job. +) -> Tuple[str, str, Optional[str], Optional[str]]: + """Returns the model ID and version and config name inferred from a training job. Raises: ValueError: If the training job does not have tags from which the model ID @@ -194,9 +215,12 @@ def get_model_id_version_from_training_job( f"arn:{partition}:sagemaker:{region}:{account_id}:training-job/{training_job_name}" ) - model_id, inferred_model_version = get_jumpstart_model_id_version_from_resource_arn( - training_job_arn, sagemaker_session - ) + ( + model_id, + inferred_model_version, + inference_config_name, + training_config_name, + ) = get_jumpstart_model_info_from_resource_arn(training_job_arn, sagemaker_session) model_version = inferred_model_version or None @@ -207,4 +231,4 @@ def get_model_id_version_from_training_job( "for this training job." ) - return model_id, model_version + return model_id, model_version, inference_config_name, training_config_name diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index dae879494e..f545425a51 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -10,12 +10,22 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
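Before the `types.py` hunks below, a hedged sketch (not part of the diff) of consuming the widened return value of `get_model_info_from_endpoint` above; the endpoint name is invented.

```python
# Hypothetical caller of the renamed session_utils helper above; note the
# five-element return order established in get_model_info_from_endpoint.
from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint

(
    model_id,
    model_version,
    inference_component_name,
    inference_config_name,
    training_config_name,
) = get_model_info_from_endpoint(endpoint_name="my-jumpstart-endpoint")

print(model_id, model_version, inference_config_name)
```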
+# pylint: skip-file """This module stores types related to SageMaker JumpStart.""" from __future__ import absolute_import +import re from copy import deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, Set, Union -from sagemaker.utils import get_instance_type_family, format_tags, Tags, deep_override_dict +from typing import Any, Callable, Dict, List, Optional, Set, Union +from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig +from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard +from sagemaker.utils import ( + S3_PREFIX, + get_instance_type_family, + format_tags, + Tags, + deep_override_dict, +) from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -29,6 +39,11 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType +from sagemaker.jumpstart.hub.parser_utils import ( + camel_to_snake, + walk_and_apply_json, +) +from sagemaker.model_life_cycle import ModelLifeCycle class JumpStartDataHolderType: @@ -113,13 +128,34 @@ class JumpStartS3FileType(str, Enum): PROPRIETARY_SPECS = "proprietary_specs" + +class HubType(str, Enum): + """Enum for Hub objects.""" + + HUB = "Hub" + + +class HubContentType(str, Enum): + """Enum for Hub content objects.""" + + MODEL = "Model" + NOTEBOOK = "Notebook" + MODEL_REFERENCE = "ModelReference" + + +JumpStartContentDataType = Union[JumpStartS3FileType, HubType, HubContentType] + + class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): """Data class for launched region info.""" - __slots__ = ["content_bucket", "region_name", "gated_content_bucket"] + __slots__ = ["content_bucket", "region_name", "gated_content_bucket", "neo_content_bucket"] def __init__( - self, content_bucket: str, region_name: str, gated_content_bucket: Optional[str] = None + self, + content_bucket: str, + region_name: str, + gated_content_bucket: Optional[str] = None, + neo_content_bucket: Optional[str] = None, ): """Instantiates JumpStartLaunchedRegionInfo object. @@ -128,10 +164,13 @@ def __init__( region_name (str): Name of JumpStart launched region. gated_content_bucket (Optional[str]): Name of JumpStart gated S3 content bucket optionally associated with region. + neo_content_bucket (Optional[str]): Name of Neo service S3 content bucket + optionally associated with region. """ self.content_bucket = content_bucket self.gated_content_bucket = gated_content_bucket self.region_name = region_name + self.neo_content_bucket = neo_content_bucket class JumpStartModelHeader(JumpStartDataHolderType): @@ -177,14 +216,18 @@ class JumpStartECRSpecs(JumpStartDataHolderType): "framework_version", "py_version", "huggingface_transformers_version", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartECRSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec.
""" + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -197,6 +240,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if not json_obj: return + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) + self.framework = json_obj.get("framework") self.framework_version = json_obj.get("framework_version") self.py_version = json_obj.get("py_version") @@ -206,7 +252,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartECRSpecs object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -223,14 +273,18 @@ class JumpStartHyperparameter(JumpStartDataHolderType): "max", "exclusive_min", "exclusive_max", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartHyperparameter object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of hyperparameter. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -240,6 +294,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of hyperparameter. """ + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -257,17 +313,24 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if max_val is not None: self.max = max_val + # HubContentDocument model schema does not allow exclusive min/max. + if self._is_hub_content: + return + exclusive_min_val = json_obj.get("exclusive_min") + exclusive_max_val = json_obj.get("exclusive_max") if exclusive_min_val is not None: self.exclusive_min = exclusive_min_val - - exclusive_max_val = json_obj.get("exclusive_max") if exclusive_max_val is not None: self.exclusive_max = exclusive_max_val def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartHyperparameter object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -280,14 +343,18 @@ class JumpStartEnvironmentVariable(JumpStartDataHolderType): "default", "scope", "required_for_model_class", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartEnvironmentVariable object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of environment variable. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -296,7 +363,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of environment variable. 
""" - + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -305,7 +372,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartEnvironmentVariable object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -317,14 +388,18 @@ class JumpStartPredictorSpecs(JumpStartDataHolderType): "supported_content_types", "default_accept_type", "supported_accept_types", + "_is_hub_content", ] - def __init__(self, spec: Optional[Dict[str, Any]]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartPredictorSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of predictor specs. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: @@ -337,6 +412,8 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.default_content_type = json_obj["default_content_type"] self.supported_content_types = json_obj["supported_content_types"] self.default_accept_type = json_obj["default_accept_type"] @@ -344,7 +421,11 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartPredictorSpecs object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -357,16 +438,18 @@ class JumpStartSerializablePayload(JumpStartDataHolderType): "accept", "body", "prompt_key", + "_is_hub_content", ] - _non_serializable_slots = ["raw_payload", "prompt_key"] + _non_serializable_slots = ["raw_payload", "prompt_key", "_is_hub_content"] - def __init__(self, spec: Optional[Dict[str, Any]]): + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartSerializablePayload object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of payload specs. 
""" + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: @@ -383,9 +466,11 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.raw_payload = json_obj self.content_type = json_obj["content_type"] - self.body = json_obj["body"] + self.body = json_obj.get("body") accept = json_obj.get("accept") self.prompt_key = json_obj.get("prompt_key") if accept: @@ -401,16 +486,26 @@ class JumpStartInstanceTypeVariants(JumpStartDataHolderType): __slots__ = [ "regional_aliases", + "aliases", "variants", + "_is_hub_content", ] - def __init__(self, spec: Optional[Dict[str, Any]]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartInstanceTypeVariants object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of instance type variants. """ - self.from_json(spec) + + self._is_hub_content = is_hub_content + + if self._is_hub_content: + self.from_describe_hub_content_response(spec) + else: + self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: """Sets fields in object based on json. @@ -422,14 +517,50 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + self.aliases = None self.regional_aliases: Optional[dict] = json_obj.get("regional_aliases") self.variants: Optional[dict] = json_obj.get("variants") def to_json(self) -> Dict[str, Any]: - """Returns json representation of JumpStartInstanceTypeVariants object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + """Returns json representation of JumpStartInstance object.""" + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj + def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on DescribeHubContent response. + + Args: + response (Dict[str, Any]): Dictionary representation of instance type variants. 
+ """ + + if response is None: + return + + response = walk_and_apply_json(response, camel_to_snake) + self.aliases: Optional[dict] = response.get("aliases") + self.regional_aliases = None + self.variants: Optional[dict] = response.get("variants") + + def regionalize( # pylint: disable=inconsistent-return-statements + self, region: str + ) -> Optional[Dict[str, Any]]: + """Returns regionalized instance type variants.""" + + if self.regional_aliases is None or self.aliases is not None: + return + aliases = self.regional_aliases.get(region, {}) + variants = {} + for instance_name, properties in self.variants.items(): + if properties.get("regional_properties") is not None: + variants.update({instance_name: properties.get("regional_properties")}) + if properties.get("properties") is not None: + variants.update({instance_name: properties.get("properties")}) + return {"Aliases": aliases, "Variants": variants} + def get_instance_specific_metric_definitions( self, instance_type: str ) -> List[JumpStartHyperparameter]: @@ -488,6 +619,19 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str instance_type=instance_type, property_name="artifact_key" ) + def get_instance_specific_training_artifact_key(self, instance_type: str) -> Optional[str]: + """Returns instance specific training artifact key. + + Returns None if a model, instance type tuple does not have specific + training artifact key. + """ + + return self._get_instance_specific_property( + instance_type=instance_type, property_name="training_artifact_uri" + ) or self._get_instance_specific_property( + instance_type=instance_type, property_name="training_artifact_key" + ) + def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]: """Returns instance specific resource requirements. @@ -627,7 +771,12 @@ def get_instance_specific_gated_model_key_env_var_value( Returns None if a model, instance type tuple does not have instance specific property. """ - return self._get_instance_specific_property(instance_type, "gated_model_key_env_var_value") + + gated_model_key_env_var_value = ( + "gated_model_env_var_uri" if self._is_hub_content else "gated_model_key_env_var_value" + ) + + return self._get_instance_specific_property(instance_type, gated_model_key_env_var_value) def get_instance_specific_default_inference_instance_type( self, instance_type: str @@ -679,7 +828,7 @@ def get_instance_specific_supported_inference_instance_types( ) ) - def get_image_uri(self, instance_type: str, region: str) -> Optional[str]: + def get_image_uri(self, instance_type: str, region: Optional[str] = None) -> Optional[str]: """Returns image uri from instance type and region. Returns None if no instance type is available or found. @@ -700,36 +849,61 @@ def get_model_package_arn(self, instance_type: str, region: str) -> Optional[str ) def _get_regional_property( - self, instance_type: str, region: str, property_name: str + self, instance_type: str, region: Optional[str], property_name: str ) -> Optional[str]: """Returns regional property from instance type and region. Returns None if no instance type is available or found. None is also returned if the metadata is improperly formatted. 
""" + # pylint: disable=too-many-return-statements + # if self.variants is None or (self.aliases is None and self.regional_aliases is None): + # return None - if None in [self.regional_aliases, self.variants]: + if self.variants is None: return None - regional_property_alias: Optional[str] = ( - self.variants.get(instance_type, {}).get("regional_properties", {}).get(property_name) - ) - if regional_property_alias is None: - instance_type_family = get_instance_type_family(instance_type) + if region is None and self.regional_aliases is not None: + return None - if instance_type_family in {"", None}: - return None + regional_property_alias: Optional[str] = None + regional_property_value: Optional[str] = None + if self.regional_aliases: regional_property_alias = ( - self.variants.get(instance_type_family, {}) + self.variants.get(instance_type, {}) .get("regional_properties", {}) .get(property_name) ) + else: + regional_property_value = ( + self.variants.get(instance_type, {}).get("properties", {}).get(property_name) + ) + + if regional_property_alias is None and regional_property_value is None: + instance_type_family = get_instance_type_family(instance_type) + if instance_type_family in {"", None}: + return None + if self.regional_aliases: + regional_property_alias = ( + self.variants.get(instance_type_family, {}) + .get("regional_properties", {}) + .get(property_name) + ) + else: + # if reading from HubContent, aliases are already regionalized + regional_property_value = ( + self.variants.get(instance_type_family, {}) + .get("properties", {}) + .get(property_name) + ) - if regional_property_alias is None or len(regional_property_alias) == 0: + if (regional_property_alias is None or len(regional_property_alias) == 0) and ( + regional_property_value is None or len(regional_property_value) == 0 + ): return None - if not regional_property_alias.startswith("$"): + if regional_property_alias and not regional_property_alias.startswith("$"): # No leading '$' indicates bad metadata. # There are tests to ensure this never happens. # However, to allow for fallback options in the unlikely event @@ -737,16 +911,250 @@ def _get_regional_property( # We return None, indicating the field does not exist. return None - if region not in self.regional_aliases: + if self.regional_aliases and region not in self.regional_aliases: return None - alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) - return alias_value + + if self.regional_aliases: + alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) + return alias_value + return regional_property_value + + +class JumpStartAdditionalDataSources(JumpStartDataHolderType): + """Data class of additional data sources.""" + + __slots__ = ["speculative_decoding", "scripts"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a AdditionalDataSources object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. 
+ """ + self.speculative_decoding: Optional[List[JumpStartModelDataSource]] = ( + [ + JumpStartModelDataSource(data_source) + for data_source in json_obj["speculative_decoding"] + ] + if json_obj.get("speculative_decoding") + else None + ) + self.scripts: Optional[List[JumpStartModelDataSource]] = ( + [JumpStartModelDataSource(data_source) for data_source in json_obj["scripts"]] + if json_obj.get("scripts") + else None + ) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of AdditionalDataSources object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj + + +class ModelAccessConfig(JumpStartDataHolderType): + """Data class of model access config that mirrors CreateModel API.""" + + __slots__ = ["accept_eula"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a ModelAccessConfig object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. + """ + self.accept_eula: bool = json_obj["accept_eula"] + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of ModelAccessConfig object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + return json_obj + + +class HubAccessConfig(JumpStartDataHolderType): + """Data class of model access config that mirrors CreateModel API.""" + + __slots__ = ["hub_content_arn"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a HubAccessConfig object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. + """ + self.hub_content_arn: bool = json_obj["accept_eula"] + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of ModelAccessConfig object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + return json_obj + + +class S3DataSource(JumpStartDataHolderType): + """Data class of S3 data source that mirrors CreateModel API.""" + + __slots__ = [ + "compression_type", + "s3_data_type", + "s3_uri", + "model_access_config", + "hub_access_config", + ] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a S3DataSource object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. 
+ """ + self.compression_type: str = json_obj["compression_type"] + self.s3_data_type: str = json_obj["s3_data_type"] + self.s3_uri: str = json_obj["s3_uri"] + self.model_access_config: ModelAccessConfig = ( + ModelAccessConfig(json_obj["model_access_config"]) + if json_obj.get("model_access_config") + else None + ) + self.hub_access_config: HubAccessConfig = ( + HubAccessConfig(json_obj["hub_access_config"]) + if json_obj.get("hub_access_config") + else None + ) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of S3DataSource object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif cur_val: + json_obj[att] = cur_val + return json_obj + + def set_bucket(self, bucket: str) -> None: + """Sets bucket name from S3 URI.""" + + if self.s3_uri.startswith(S3_PREFIX): + s3_path = self.s3_uri[len(S3_PREFIX) :] + old_bucket = s3_path.split("/")[0] + key = s3_path[len(old_bucket) :] + self.s3_uri = f"{S3_PREFIX}{bucket}{key}" # pylint: disable=W0201 + return + + if not bucket.endswith("/"): + bucket += "/" + + self.s3_uri = f"{S3_PREFIX}{bucket}{self.s3_uri}" # pylint: disable=W0201 + + +class AdditionalModelDataSource(JumpStartDataHolderType): + """Data class of additional model data source mirrors CreateModel API.""" + + SERIALIZATION_EXCLUSION_SET = {"provider"} + + __slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"] + + def __init__(self, spec: Dict[str, Any]): + """Initializes a AdditionalModelDataSource object. + + Args: + spec (Dict[str, Any]): Dictionary representation of data source. + """ + self.from_json(spec) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. + """ + self.channel_name: str = json_obj["channel_name"] + self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"]) + self.hosting_eula_key: str = json_obj.get("hosting_eula_key") + self.provider: Dict = json_obj.get("provider", {}) + + def to_json(self, exclude_keys=True) -> Dict[str, Any]: + """Returns json representation of AdditionalModelDataSource object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + if exclude_keys and att not in self.SERIALIZATION_EXCLUSION_SET or not exclude_keys: + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + else: + json_obj[att] = cur_val + return json_obj + + +class JumpStartModelDataSource(AdditionalModelDataSource): + """Data class JumpStart additional model data source.""" + + SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union( + {"artifact_version"} + ) + + __slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__ + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of data source. + """ + super().from_json(json_obj) + self.artifact_version: str = json_obj["artifact_version"] class JumpStartBenchmarkStat(JumpStartDataHolderType): """Data class JumpStart benchmark stat.""" - __slots__ = ["name", "value", "unit"] + __slots__ = ["name", "value", "unit", "concurrency"] def __init__(self, spec: Dict[str, Any]): """Initializes a JumpStartBenchmarkStat object. 
@@ -765,6 +1173,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.name: str = json_obj["name"] self.value: str = json_obj["value"] self.unit: Union[int, str] = json_obj["unit"] + self.concurrency: Union[int, str] = json_obj["concurrency"] def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartBenchmarkStat object.""" @@ -777,12 +1186,14 @@ class JumpStartConfigRanking(JumpStartDataHolderType): __slots__ = ["description", "rankings"] - def __init__(self, spec: Optional[Dict[str, Any]]): + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content=False): """Initializes a JumpStartConfigRanking object. Args: spec (Dict[str, Any]): Dictionary representation of training config ranking. """ + if is_hub_content: + spec = walk_and_apply_json(spec, camel_to_snake) self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -808,12 +1219,17 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "url", "version", "min_sdk_version", + "model_types", + "capabilities", "incremental_training_supported", "hosting_ecr_specs", + "hosting_ecr_uri", + "hosting_artifact_uri", "hosting_artifact_key", "hosting_script_key", "training_supported", "training_ecr_specs", + "training_ecr_uri", "training_artifact_key", "training_script_key", "hyperparameters", @@ -836,7 +1252,9 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "supported_training_instance_types", "metrics", "training_prepacked_script_key", + "training_prepacked_script_version", "hosting_prepacked_artifact_key", + "hosting_prepacked_artifact_version", "model_kwargs", "deploy_kwargs", "estimator_kwargs", @@ -856,14 +1274,24 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "default_payloads", "gated_bucket", "model_subscription_link", + "hosting_additional_data_sources", + "hosting_neuron_model_id", + "hosting_neuron_model_version", + "hub_content_type", + "_is_hub_content", + "default_training_dataset_key", + "default_training_dataset_uri", ] - def __init__(self, fields: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, fields: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartMetadataFields object. Args: fields (Dict[str, Any]): Dictionary representation of metadata fields. """ + self._is_hub_content = is_hub_content self.from_json(fields) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -872,6 +1300,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of spec. 
""" + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.model_id: str = json_obj.get("model_id") self.url: str = json_obj.get("url") self.version: str = json_obj.get("version") @@ -879,16 +1309,26 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.incremental_training_supported: bool = bool( json_obj.get("incremental_training_supported", False) ) - self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( - JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) - if "hosting_ecr_specs" in json_obj - else None - ) + if self._is_hub_content: + self.capabilities: Optional[List[str]] = json_obj.get("capabilities") + self.model_types: Optional[List[str]] = json_obj.get("model_types") + self.hosting_ecr_uri: Optional[str] = json_obj.get("hosting_ecr_uri") + self._non_serializable_slots.append("hosting_ecr_specs") + else: + self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( + JumpStartECRSpecs( + json_obj["hosting_ecr_specs"], is_hub_content=self._is_hub_content + ) + if "hosting_ecr_specs" in json_obj + else None + ) + self._non_serializable_slots.append("hosting_ecr_uri") self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key") + self.hosting_artifact_uri: Optional[str] = json_obj.get("hosting_artifact_uri") self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key") self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False)) self.inference_environment_variables = [ - JumpStartEnvironmentVariable(env_variable) + JumpStartEnvironmentVariable(env_variable, is_hub_content=self._is_hub_content) for env_variable in json_obj.get("inference_environment_variables", []) ] self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False)) @@ -926,16 +1366,27 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get( "hosting_prepacked_artifact_key", None ) + # New fields required for Hub model. 
+ if self._is_hub_content: + self.training_prepacked_script_version: Optional[str] = json_obj.get( + "training_prepacked_script_version" + ) + self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get( + "hosting_prepacked_artifact_version" + ) self.model_kwargs = deepcopy(json_obj.get("model_kwargs", {})) self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {})) self.predictor_specs: Optional[JumpStartPredictorSpecs] = ( - JumpStartPredictorSpecs(json_obj["predictor_specs"]) - if "predictor_specs" in json_obj + JumpStartPredictorSpecs( + json_obj.get("predictor_specs"), + is_hub_content=self._is_hub_content, + ) + if json_obj.get("predictor_specs") else None ) self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( { - alias: JumpStartSerializablePayload(payload) + alias: JumpStartSerializablePayload(payload, is_hub_content=self._is_hub_content) for alias, payload in json_obj["default_payloads"].items() } if json_obj.get("default_payloads") @@ -950,28 +1401,51 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key") - self.hosting_model_package_arns: Optional[Dict] = json_obj.get("hosting_model_package_arns") + model_package_arns = json_obj.get("hosting_model_package_arns") + self.hosting_model_package_arns: Optional[Dict] = ( + model_package_arns if model_package_arns is not None else {} + ) + self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True) self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( - JumpStartInstanceTypeVariants(json_obj["hosting_instance_type_variants"]) + JumpStartInstanceTypeVariants( + json_obj["hosting_instance_type_variants"], self._is_hub_content + ) if json_obj.get("hosting_instance_type_variants") else None ) + self.hosting_additional_data_sources: Optional[JumpStartAdditionalDataSources] = ( + JumpStartAdditionalDataSources(json_obj["hosting_additional_data_sources"]) + if json_obj.get("hosting_additional_data_sources") + else None + ) + self.hosting_neuron_model_id: Optional[str] = json_obj.get("hosting_neuron_model_id") + self.hosting_neuron_model_version: Optional[str] = json_obj.get( + "hosting_neuron_model_version" + ) if self.training_supported: - self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( - JumpStartECRSpecs(json_obj["training_ecr_specs"]) - if "training_ecr_specs" in json_obj - else None - ) + if self._is_hub_content: + self.training_ecr_uri: Optional[str] = json_obj.get("training_ecr_uri") + self._non_serializable_slots.append("training_ecr_specs") + else: + self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( + JumpStartECRSpecs(json_obj["training_ecr_specs"]) + if "training_ecr_specs" in json_obj + else None + ) + self._non_serializable_slots.append("training_ecr_uri") self.training_artifact_key: str = json_obj["training_artifact_key"] self.training_script_key: str = json_obj["training_script_key"] hyperparameters: Any = json_obj.get("hyperparameters") self.hyperparameters: List[JumpStartHyperparameter] = [] if hyperparameters is not None: self.hyperparameters.extend( - [JumpStartHyperparameter(hyperparameter) for hyperparameter in hyperparameters] + [ + JumpStartHyperparameter(hyperparameter, is_hub_content=self._is_hub_content) + for hyperparameter in hyperparameters + ] ) self.estimator_kwargs = deepcopy(json_obj.get("estimator_kwargs", {})) self.fit_kwargs = deepcopy(json_obj.get("fit_kwargs", {})) @@ -983,17 +1457,25 @@ def from_json(self, json_obj: Dict[str, 
Any]) -> None: "training_model_package_artifact_uris" ) self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( - JumpStartInstanceTypeVariants(json_obj["training_instance_type_variants"]) + JumpStartInstanceTypeVariants( + json_obj["training_instance_type_variants"], is_hub_content=self._is_hub_content + ) if json_obj.get("training_instance_type_variants") else None ) self.model_subscription_link = json_obj.get("model_subscription_link") + self.default_training_dataset_key: Optional[str] = json_obj.get( + "default_training_dataset_key" + ) + self.default_training_dataset_uri: Optional[str] = json_obj.get( + "default_training_dataset_uri" + ) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataBaseFields object.""" json_obj = {} for att in self.__slots__: - if hasattr(self, att): + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []): cur_val = getattr(self, att) if issubclass(type(cur_val), JumpStartDataHolderType): json_obj[att] = cur_val.to_json() @@ -1015,6 +1497,11 @@ def to_json(self) -> Dict[str, Any]: json_obj[att] = cur_val return json_obj + def set_hub_content_type(self, hub_content_type: HubContentType) -> None: + """Sets the hub content type.""" + if self._is_hub_content: + self.hub_content_type = hub_content_type + class JumpStartConfigComponent(JumpStartMetadataBaseFields): """Data class of JumpStart config component.""" @@ -1036,12 +1523,13 @@ class JumpStartConfigComponent(JumpStartMetadataBaseFields): "incremental_training_supported", ] + # Map of HubContent fields that map to custom names in MetadataBaseFields + CUSTOM_FIELD_MAP = {"sage_maker_sdk_predictor_specifications": "predictor_specs"} + __slots__ = slots + JumpStartMetadataBaseFields.__slots__ def __init__( - self, - component_name: str, - component: Optional[Dict[str, Any]], + self, component_name: str, component: Optional[Dict[str, Any]], is_hub_content=False ): """Initializes a JumpStartConfigComponent object from its json representation. @@ -1052,8 +1540,10 @@ def __init__( Raises: ValueError: If the component field is invalid. """ - super().__init__(component) + if is_hub_content: + component = walk_and_apply_json(component, camel_to_snake) self.component_name = component_name + super().__init__(component, is_hub_content) self.from_json(component) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -1064,9 +1554,13 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Dictionary representation of the config component. 
""" for field in json_obj.keys(): - if field not in self.__slots__: - raise ValueError(f"Invalid component field: {field}") - setattr(self, field, json_obj[field]) + if field in self.__slots__: + setattr(self, field, json_obj[field]) + + # Handle custom fields + for custom_field, field in self.CUSTOM_FIELD_MAP.items(): + if custom_field in json_obj: + setattr(self, field, json_obj.get(custom_field)) class JumpStartMetadataConfig(JumpStartDataHolderType): @@ -1075,30 +1569,61 @@ class JumpStartMetadataConfig(JumpStartDataHolderType): __slots__ = [ "base_fields", "benchmark_metrics", + "acceleration_configs", "config_components", "resolved_metadata_config", + "config_name", + "default_inference_config", + "default_incremental_training_config", + "supported_inference_configs", + "supported_incremental_training_configs", ] def __init__( self, + config_name: str, + config: Dict[str, Any], base_fields: Dict[str, Any], config_components: Dict[str, JumpStartConfigComponent], - benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]], + is_hub_content=False, ): """Initializes a JumpStartMetadataConfig object from its json representation. Args: + config_name (str): Name of the config, + config (Dict[str, Any]): + Dictionary representation of the config. base_fields (Dict[str, Any]): - The default base fields that are used to construct the final resolved config. + The default base fields that are used to construct the resolved config. config_components (Dict[str, JumpStartConfigComponent]): The list of components that are used to construct the resolved config. - benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]): - The dictionary of benchmark metrics with name being the key. """ + if is_hub_content: + config = walk_and_apply_json(config, camel_to_snake) + base_fields = walk_and_apply_json(base_fields, camel_to_snake) self.base_fields = base_fields self.config_components: Dict[str, JumpStartConfigComponent] = config_components - self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics + self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = ( + { + stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] + for stat_name, stats in config.get("benchmark_metrics").items() + } + if config and config.get("benchmark_metrics") + else None + ) + self.acceleration_configs = config.get("acceleration_configs") self.resolved_metadata_config: Optional[Dict[str, Any]] = None + self.config_name: Optional[str] = config_name + self.default_inference_config: Optional[str] = config.get("default_inference_config") + self.default_incremental_training_config: Optional[str] = config.get( + "default_incremental_training_config" + ) + self.supported_inference_configs: Optional[List[str]] = config.get( + "supported_inference_configs" + ) + self.supported_incremental_training_configs: Optional[List[str]] = config.get( + "supported_incremental_training_configs" + ) def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataConfig object.""" @@ -1122,6 +1647,12 @@ def resolved_config(self) -> Dict[str, Any]: deepcopy(component.to_json()), component.OVERRIDING_DENY_LIST, ) + + # Remove environment variables from resolved config if using model packages + hosting_model_pacakge_arns = resolved_config.get("hosting_model_package_arns") + if hosting_model_pacakge_arns is not None and hosting_model_pacakge_arns != {}: + resolved_config["inference_environment_variables"] = [] + self.resolved_metadata_config = resolved_config return resolved_config @@ -1164,6 
+1695,8 @@ def get_top_config_from_ranking( ) -> Optional[JumpStartMetadataConfig]: """Gets the best the config based on config ranking. + Fallback to use the ordering in config names if + ranking is not available. Args: ranking_name (str): The ranking name that config priority is based on. @@ -1171,13 +1704,8 @@ def get_top_config_from_ranking( The instance type which the config selection is based on. Raises: - ValueError: If the config exists but missing config ranking. NotImplementedError: If the scope is unrecognized. """ - if self.configs and ( - not self.config_rankings or not self.config_rankings.get(ranking_name) - ): - raise ValueError(f"Config exists but missing config ranking {ranking_name}.") if self.scope == JumpStartScriptScope.INFERENCE: instance_type_attribute = "supported_inference_instance_types" @@ -1186,8 +1714,14 @@ def get_top_config_from_ranking( else: raise NotImplementedError(f"Unknown script scope {self.scope}") - rankings = self.config_rankings.get(ranking_name) - for config_name in rankings.rankings: + if self.configs and ( + not self.config_rankings or not self.config_rankings.get(ranking_name) + ): + ranked_config_names = sorted(list(self.configs.keys())) + else: + rankings = self.config_rankings.get(ranking_name) + ranked_config_names = rankings.rankings + for config_name in ranked_config_names: resolved_config = self.configs[config_name].resolved_config if instance_type and instance_type not in getattr( resolved_config, instance_type_attribute @@ -1212,13 +1746,14 @@ class JumpStartModelSpecs(JumpStartMetadataBaseFields): __slots__ = JumpStartMetadataBaseFields.__slots__ + slots - def __init__(self, spec: Dict[str, Any]): + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartModelSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. + is_hub_content (Optional[bool]): Whether the model is from a private hub. """ - super().__init__(spec) + super().__init__(spec, is_hub_content) self.from_json(spec) if self.inference_configs and self.inference_configs.get_top_config_from_ranking(): super().from_json(self.inference_configs.get_top_config_from_ranking().resolved_config) @@ -1230,6 +1765,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of spec. 
""" super().from_json(json_obj) + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.inference_config_components: Optional[Dict[str, JumpStartConfigComponent]] = ( { component_name: JumpStartConfigComponent(component_name, component) @@ -1240,38 +1777,50 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: ) self.inference_config_rankings: Optional[Dict[str, JumpStartConfigRanking]] = ( { - alias: JumpStartConfigRanking(ranking) + alias: JumpStartConfigRanking(ranking, is_hub_content=self._is_hub_content) for alias, ranking in json_obj["inference_config_rankings"].items() } if json_obj.get("inference_config_rankings") else None ) - inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( - { - alias: JumpStartMetadataConfig( - json_obj, - ( - { - component_name: self.inference_config_components.get(component_name) - for component_name in config.get("component_names") - } - if config and config.get("component_names") - else None - ), - ( - { - stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] - for stat_name, stats in config.get("benchmark_metrics").items() - } - if config and config.get("benchmark_metrics") - else None - ), - ) - for alias, config in json_obj["inference_configs"].items() - } - if json_obj.get("inference_configs") - else None - ) + + if self._is_hub_content: + inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( + { + alias: JumpStartMetadataConfig( + alias, + config, + json_obj, + config.config_components, + is_hub_content=self._is_hub_content, + ) + for alias, config in json_obj["inference_configs"]["configs"].items() + } + if json_obj.get("inference_configs") + else None + ) + else: + inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( + { + alias: JumpStartMetadataConfig( + alias, + config, + json_obj, + ( + { + component_name: self.inference_config_components.get(component_name) + for component_name in config.get("component_names") + } + if config and config.get("component_names") + else None + ), + ) + for alias, config in json_obj["inference_configs"].items() + } + if json_obj.get("inference_configs") + else None + ) + self.inference_configs: Optional[JumpStartMetadataConfigs] = ( JumpStartMetadataConfigs( inference_configs_dict, @@ -1298,32 +1847,45 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if json_obj.get("training_config_rankings") else None ) - training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( - { - alias: JumpStartMetadataConfig( - json_obj, - ( - { - component_name: self.training_config_components.get(component_name) - for component_name in config.get("component_names") - } - if config and config.get("component_names") - else None - ), - ( - { - stat_name: [JumpStartBenchmarkStat(stat) for stat in stats] - for stat_name, stats in config.get("benchmark_metrics").items() - } - if config and config.get("benchmark_metrics") - else None - ), - ) - for alias, config in json_obj["training_configs"].items() - } - if json_obj.get("training_configs") - else None - ) + + if self._is_hub_content: + training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( + { + alias: JumpStartMetadataConfig( + alias, + config, + json_obj, + config.config_components, + is_hub_content=self._is_hub_content, + ) + for alias, config in json_obj["training_configs"]["configs"].items() + } + if json_obj.get("training_configs") + else None + ) + else: + training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = ( + { + alias: 
+                        alias,
+                        config,
+                        json_obj,
+                        (
+                            {
+                                component_name: self.training_config_components.get(
+                                    component_name
+                                )
+                                for component_name in config.get("component_names")
+                            }
+                            if config and config.get("component_names")
+                            else None
+                        ),
+                    )
+                    for alias, config in json_obj["training_configs"].items()
+                }
+                if json_obj.get("training_configs")
+                else None
+            )

         self.training_configs: Optional[JumpStartMetadataConfigs] = (
             JumpStartMetadataConfigs(
@@ -1378,12 +1940,20 @@ def use_inference_script_uri(self) -> bool:

     def use_training_model_artifact(self) -> bool:
         """Returns True if the model should use a model uri when kicking off training job."""
-        # gated model never use training model artifact
-        if self.gated_bucket:
+        # old models with this environment variable present don't use the model channel
+        if any(
+            self.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value(
+                instance_type
+            )
+            for instance_type in self.supported_training_instance_types
+        ):
             return False

-        # otherwise, return true is a training model package is not set
-        return len(self.training_model_package_artifact_uris or {}) == 0
+        # even older models, with training model package artifact uris present,
+        # also don't use the model channel
+        if len(self.training_model_package_artifact_uris or {}) > 0:
+            return False
+
+        return getattr(self, "training_artifact_key", None) is not None

     def is_gated_model(self) -> bool:
         """Returns True if the model has a EULA key or the model bucket is gated."""
@@ -1393,6 +1963,21 @@ def supports_incremental_training(self) -> bool:
         """Returns True if the model supports incremental training."""
         return self.incremental_training_supported

+    def get_speculative_decoding_s3_data_sources(self) -> List[JumpStartModelDataSource]:
+        """Returns data sources for speculative decoding."""
+        if not self.hosting_additional_data_sources:
+            return []
+        return self.hosting_additional_data_sources.speculative_decoding or []
+
+    def get_additional_s3_data_sources(self) -> List[JumpStartModelDataSource]:
+        """Returns a list of the additional S3 data sources for use by the model."""
+        additional_data_sources = []
+        if self.hosting_additional_data_sources:
+            for data_source in self.hosting_additional_data_sources.to_json():
+                data_sources = getattr(self.hosting_additional_data_sources, data_source) or []
+                additional_data_sources.extend(data_sources)
+        return additional_data_sources
+

 class JumpStartVersionedModelId(JumpStartDataHolderType):
     """Data class for versioned model IDs."""
@@ -1414,27 +1999,83 @@ def __init__(
         self.version = version


-class JumpStartCachedS3ContentKey(JumpStartDataHolderType):
-    """Data class for the s3 cached content keys."""
+class JumpStartCachedContentKey(JumpStartDataHolderType):
+    """Data class for the cached content keys."""

-    __slots__ = ["file_type", "s3_key"]
+    __slots__ = ["data_type", "id_info"]

     def __init__(
         self,
-        file_type: JumpStartS3FileType,
-        s3_key: str,
+        data_type: JumpStartContentDataType,
+        id_info: str,
     ) -> None:
-        """Instantiates JumpStartCachedS3ContentKey object.
+        """Instantiates JumpStartCachedContentKey object.

         Args:
-            file_type (JumpStartS3FileType): JumpStart file type.
-            s3_key (str): object key in s3.
+            data_type (JumpStartContentDataType): JumpStart content data type.
+            id_info (str): If S3Content, the object key in S3; if HubContent, the hub content ARN.
""" - self.file_type = file_type - self.s3_key = s3_key + self.data_type = data_type + self.id_info = id_info -class JumpStartCachedS3ContentValue(JumpStartDataHolderType): +class HubArnExtractedInfo(JumpStartDataHolderType): + """Data class for info extracted from Hub arn.""" + + __slots__ = [ + "partition", + "region", + "account_id", + "hub_name", + "hub_content_type", + "hub_content_name", + "hub_content_version", + ] + + def __init__( + self, + partition: str, + region: str, + account_id: str, + hub_name: str, + hub_content_type: Optional[str] = None, + hub_content_name: Optional[str] = None, + hub_content_version: Optional[str] = None, + ) -> None: + """Instantiates HubArnExtractedInfo object.""" + + self.partition = partition + self.region = region + self.account_id = account_id + self.hub_name = hub_name + self.hub_content_name = hub_content_name + self.hub_content_type = hub_content_type + self.hub_content_version = hub_content_version + + @staticmethod + def extract_region_from_arn(arn: str) -> Optional[str]: + """Extracts hub_name, content_name, and content_version from a HubContentArn""" + + HUB_CONTENT_ARN_REGEX = ( + r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" + ) + HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" + + match = re.match(HUB_CONTENT_ARN_REGEX, arn) + hub_region = None + if match: + hub_region = match.group(2) + return hub_region + + match = re.match(HUB_ARN_REGEX, arn) + if match: + hub_region = match.group(2) + return hub_region + + return hub_region + + +class JumpStartCachedContentValue(JumpStartDataHolderType): """Data class for the s3 cached content values.""" __slots__ = ["formatted_content", "md5_hash"] @@ -1447,7 +2088,7 @@ def __init__( ], md5_hash: Optional[str] = None, ) -> None: - """Instantiates JumpStartCachedS3ContentValue object. + """Instantiates JumpStartCachedContentValue object. 
+
+
+class JumpStartCachedContentValue(JumpStartDataHolderType):
     """Data class for the s3 cached content values."""

     __slots__ = ["formatted_content", "md5_hash"]

@@ -1447,7 +2088,7 @@ def __init__(
         ],
         md5_hash: Optional[str] = None,
     ) -> None:
-        """Instantiates JumpStartCachedS3ContentValue object.
+        """Instantiates JumpStartCachedContentValue object.

         Args:
             formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader],
@@ -1463,14 +2104,20 @@ class JumpStartKwargs(JumpStartDataHolderType):
     """Data class for JumpStart object kwargs."""

+    BASE_SERIALIZATION_EXCLUSION_SET: Set[str] = {"specs"}
     SERIALIZATION_EXCLUSION_SET: Set[str] = set()

-    def to_kwargs_dict(self):
+    def to_kwargs_dict(self, exclude_keys: bool = True):
         """Serializes object to dictionary to be used for kwargs for method arguments."""
         kwargs_dict = {}
         for field in self.__slots__:
-            if field not in self.SERIALIZATION_EXCLUSION_SET:
-                att_value = getattr(self, field)
+            if (
+                exclude_keys
+                and field
+                not in self.SERIALIZATION_EXCLUSION_SET.union(self.BASE_SERIALIZATION_EXCLUSION_SET)
+                or not exclude_keys
+            ):
+                att_value = getattr(self, field, None)
                 if att_value is not None:
                     kwargs_dict[field] = getattr(self, field)
         return kwargs_dict
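# Hedged sketch of the serialization rule above, using the JumpStartModelInitKwargs
# subclass defined below with hypothetical values: fields in the exclusion sets
# (model_id, model_version, instance_type, ...) are dropped unless exclude_keys=False.
from sagemaker.jumpstart.types import JumpStartModelInitKwargs

kwargs = JumpStartModelInitKwargs(
    model_id="example-model",
    model_version="1.0.0",
    instance_type="ml.g5.xlarge",
    env={"EXAMPLE": "1"},
)
kwargs.to_kwargs_dict()                    # contains "env" but none of the excluded fields
kwargs.to_kwargs_dict(exclude_keys=False)  # includes every non-None slot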
"endpoint_logging", "resources", "endpoint_type", + "config_name", + "routing_config", + "specs", + "model_access_configs", + "inference_ami_version", ] SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", "model_type", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", "sagemaker_session", "training_instance_type", + "config_name", + "model_access_configs", } def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, @@ -1640,6 +2314,7 @@ def __init__( deserializer: Optional[Any] = None, accelerator_type: Optional[str] = None, endpoint_name: Optional[str] = None, + inference_component_name: Optional[str] = None, tags: Optional[Tags] = None, kms_key: Optional[str] = None, wait: Optional[bool] = None, @@ -1656,14 +2331,20 @@ def __init__( sagemaker_session: Optional[Session] = None, training_instance_type: Optional[str] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, endpoint_type: Optional[EndpointType] = None, + config_name: Optional[str] = None, + routing_config: Optional[Dict[str, Any]] = None, + model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None, + inference_ami_version: Optional[str] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = model_type self.initial_instance_count = initial_instance_count self.instance_type = instance_type @@ -1672,6 +2353,7 @@ def __init__( self.deserializer = deserializer self.accelerator_type = accelerator_type self.endpoint_name = endpoint_name + self.inference_component_name = inference_component_name self.tags = format_tags(tags) self.kms_key = kms_key self.wait = wait @@ -1688,9 +2370,14 @@ def __init__( self.sagemaker_session = sagemaker_session self.training_instance_type = training_instance_type self.accept_eula = accept_eula + self.model_reference_arn = model_reference_arn self.endpoint_logging = endpoint_logging self.resources = resources self.endpoint_type = endpoint_type + self.config_name = config_name + self.routing_config = routing_config + self.model_access_configs = model_access_configs + self.inference_ami_version = inference_ami_version class JumpStartEstimatorInitKwargs(JumpStartKwargs): @@ -1699,6 +2386,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "instance_type", "instance_count", @@ -1751,7 +2439,13 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "disable_output_compression", "enable_infra_check", "enable_remote_debug", + "config_name", "enable_session_tag_chaining", + "hub_content_type", + "model_reference_arn", + "specs", + "training_plan", + "instance_placement_config", ] SERIALIZATION_EXCLUSION_SET = { @@ -1760,13 +2454,17 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_id", "model_version", + "hub_arn", "model_type", + "hub_content_type", + "config_name", } def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, image_uri: Optional[Union[str, 
@@ -1819,13 +2517,17 @@ def __init__(
         disable_output_compression: Optional[bool] = None,
         enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
         enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
+        config_name: Optional[str] = None,
         enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
+        training_plan: Optional[Union[str, PipelineVariable]] = None,
+        instance_placement_config: Optional[Dict] = None,
     ) -> None:
         """Instantiates JumpStartEstimatorInitKwargs object."""

         self.model_id = model_id
         self.model_version = model_version
-        self.model_type = (model_type,)
+        self.hub_arn = hub_arn
+        self.model_type = model_type
         self.instance_type = instance_type
         self.instance_count = instance_count
         self.region = region
@@ -1879,7 +2581,10 @@ def __init__(
         self.disable_output_compression = disable_output_compression
         self.enable_infra_check = enable_infra_check
         self.enable_remote_debug = enable_remote_debug
+        self.config_name = config_name
         self.enable_session_tag_chaining = enable_session_tag_chaining
+        self.training_plan = training_plan
+        self.instance_placement_config = instance_placement_config


 class JumpStartEstimatorFitKwargs(JumpStartKwargs):
@@ -1888,6 +2593,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
     __slots__ = [
         "model_id",
         "model_version",
+        "hub_arn",
         "model_type",
         "region",
         "inputs",
@@ -1898,22 +2604,27 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
         "tolerate_deprecated_model",
         "tolerate_vulnerable_model",
         "sagemaker_session",
+        "config_name",
+        "specs",
     ]

     SERIALIZATION_EXCLUSION_SET = {
         "model_id",
         "model_version",
+        "hub_arn",
         "model_type",
         "region",
         "tolerate_deprecated_model",
         "tolerate_vulnerable_model",
         "sagemaker_session",
+        "config_name",
     }

     def __init__(
         self,
         model_id: str,
         model_version: Optional[str] = None,
+        hub_arn: Optional[str] = None,
         model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
         region: Optional[str] = None,
         inputs: Optional[Union[str, Dict, Any, Any]] = None,
@@ -1924,11 +2635,13 @@ def __init__(
         tolerate_deprecated_model: Optional[bool] = None,
         tolerate_vulnerable_model: Optional[bool] = None,
         sagemaker_session: Optional[Session] = None,
+        config_name: Optional[str] = None,
     ) -> None:
         """Instantiates JumpStartEstimatorInitKwargs object."""

         self.model_id = model_id
         self.model_version = model_version
+        self.hub_arn = hub_arn
         self.model_type = model_type
         self.region = region
         self.inputs = inputs
@@ -1939,6 +2652,7 @@ def __init__(
         self.tolerate_deprecated_model = tolerate_deprecated_model
         self.tolerate_vulnerable_model = tolerate_vulnerable_model
         self.sagemaker_session = sagemaker_session
+        self.config_name = config_name


 class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
@@ -1947,6 +2661,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
     __slots__ = [
         "model_id",
         "model_version",
+        "hub_arn",
         "instance_type",
         "initial_instance_count",
         "region",
@@ -1984,6 +2699,8 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
         "tolerate_vulnerable_model",
         "model_name",
         "use_compiled_model",
+        "config_name",
+        "specs",
     ]

     SERIALIZATION_EXCLUSION_SET = {
@@ -1992,13 +2709,16 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
         "region",
         "model_id",
         "model_version",
+        "hub_arn",
         "sagemaker_session",
+        "config_name",
     }

     def __init__(
         self,
         model_id: str,
         model_version: Optional[str] = None,
+        hub_arn: Optional[str] = None,
         region: Optional[str] = None,
         initial_instance_count: Optional[int] = None,
         instance_type: Optional[str] = None,
@@ -2019,7 +2739,7 @@ def __init__(
         explainer_config: Optional[Any] = None,
         image_uri: Optional[Union[str, Any]] = None,
         role: Optional[str] = None,
-        predictor_cls: Optional[callable] = None,
+        predictor_cls: Optional[Callable] = None,
         env: Optional[Dict[str, Union[str, Any]]] = None,
         model_name: Optional[str] = None,
         vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None,
@@ -2036,11 +2756,13 @@ def __init__(
         tolerate_deprecated_model: Optional[bool] = None,
         tolerate_vulnerable_model: Optional[bool] = None,
         use_compiled_model: bool = False,
+        config_name: Optional[str] = None,
     ) -> None:
         """Instantiates JumpStartEstimatorInitKwargs object."""

         self.model_id = model_id
         self.model_version = model_version
+        self.hub_arn = hub_arn
         self.instance_type = instance_type
         self.initial_instance_count = initial_instance_count
         self.region = region
@@ -2078,6 +2800,7 @@ def __init__(
         self.tolerate_deprecated_model = tolerate_deprecated_model
         self.tolerate_vulnerable_model = tolerate_vulnerable_model
         self.use_compiled_model = use_compiled_model
+        self.config_name = config_name


 class JumpStartModelRegisterKwargs(JumpStartKwargs):
@@ -2088,7 +2811,9 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
         "tolerate_deprecated_model",
         "region",
         "model_id",
+        "model_type",
         "model_version",
+        "hub_arn",
         "sagemaker_session",
         "content_types",
         "response_types",
@@ -2112,6 +2837,11 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
         "data_input_configuration",
         "skip_model_validation",
         "source_uri",
+        "model_life_cycle",
+        "config_name",
+        "model_card",
+        "accept_eula",
+        "specs",
     ]

     SERIALIZATION_EXCLUSION_SET = {
@@ -2120,14 +2850,18 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
         "region",
         "model_id",
         "model_version",
+        "hub_arn",
         "sagemaker_session",
+        "config_name",
     }

     def __init__(
         self,
         model_id: str,
         model_version: Optional[str] = None,
+        hub_arn: Optional[str] = None,
         region: Optional[str] = None,
+        model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
         tolerate_deprecated_model: Optional[bool] = None,
         tolerate_vulnerable_model: Optional[bool] = None,
         sagemaker_session: Optional[Any] = None,
@@ -2153,11 +2887,17 @@ def __init__(
         data_input_configuration: Optional[str] = None,
         skip_model_validation: Optional[str] = None,
         source_uri: Optional[str] = None,
+        model_life_cycle: Optional[ModelLifeCycle] = None,
+        config_name: Optional[str] = None,
+        model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
+        accept_eula: Optional[bool] = None,
     ) -> None:
         """Instantiates JumpStartModelRegisterKwargs object."""

         self.model_id = model_id
         self.model_version = model_version
+        self.hub_arn = hub_arn
+        self.model_type = model_type
         self.region = region
         self.image_uri = image_uri
         self.sagemaker_session = sagemaker_session
@@ -2185,3 +2925,128 @@ def __init__(
         self.data_input_configuration = data_input_configuration
         self.skip_model_validation = skip_model_validation
         self.source_uri = source_uri
+        self.model_life_cycle = model_life_cycle
+        self.config_name = config_name
+        self.model_card = model_card
+        self.accept_eula = accept_eula
+
+
+class BaseDeploymentConfigDataHolder(JumpStartDataHolderType):
+    """Base class for Deployment Config Data."""
+
+    def _convert_to_pascal_case(self, attr_name: str) -> str:
+        """Converts a snake_case attribute name into a PascalCased string.
+
+        Args:
+            attr_name (str): The snake_case attribute name.
+        Returns:
+            str: The PascalCased attribute name.
+        """
+        return attr_name.replace("_", " ").title().replace(" ", "")
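# Worked example of the conversion above, for a hypothetical slot name:
attr_name = "model_data_download_timeout"
attr_name.replace("_", " ").title()                   # "Model Data Download Timeout"
attr_name.replace("_", " ").title().replace(" ", "")  # "ModelDataDownloadTimeout"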
+ """ + return attr_name.replace("_", " ").title().replace(" ", "") + + def to_json(self) -> Dict[str, Any]: + """Represents ``This`` object as JSON.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + att = self._convert_to_pascal_case(att) + json_obj[att] = self._val_to_json(cur_val) + return json_obj + + def _val_to_json(self, val: Any) -> Any: + """Converts the given value to JSON. + + Args: + val (Any): The value to convert. + Returns: + Any: The converted json value. + """ + if issubclass(type(val), JumpStartDataHolderType): + if isinstance(val, JumpStartBenchmarkStat): + val.name = val.name.replace("_", " ").title() + return val.to_json() + if isinstance(val, list): + list_obj = [] + for obj in val: + list_obj.append(self._val_to_json(obj)) + return list_obj + if isinstance(val, dict): + dict_obj = {} + for k, v in val.items(): + if isinstance(v, JumpStartDataHolderType): + dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v) + else: + dict_obj[k] = self._val_to_json(v) + return dict_obj + return val + + +class DeploymentArgs(BaseDeploymentConfigDataHolder): + """Dataclass representing a Deployment Args.""" + + __slots__ = [ + "image_uri", + "model_data", + "model_package_arn", + "environment", + "instance_type", + "compute_resource_requirements", + "model_data_download_timeout", + "container_startup_health_check_timeout", + "additional_data_sources", + ] + + def __init__( + self, + init_kwargs: Optional[JumpStartModelInitKwargs] = None, + deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None, + resolved_config: Optional[Dict[str, Any]] = None, + ): + """Instantiates DeploymentArgs object.""" + if init_kwargs is not None: + self.image_uri = init_kwargs.image_uri + self.model_data = init_kwargs.model_data + self.model_package_arn = init_kwargs.model_package_arn + self.instance_type = init_kwargs.instance_type + self.environment = init_kwargs.env + if init_kwargs.resources is not None: + self.compute_resource_requirements = ( + init_kwargs.resources.get_compute_resource_requirements() + ) + if deploy_kwargs is not None: + self.model_data_download_timeout = deploy_kwargs.model_data_download_timeout + self.container_startup_health_check_timeout = ( + deploy_kwargs.container_startup_health_check_timeout + ) + if resolved_config is not None: + self.default_instance_type = resolved_config.get("default_inference_instance_type") + self.supported_instance_types = resolved_config.get( + "supported_inference_instance_types" + ) + self.additional_data_sources = resolved_config.get("hosting_additional_data_sources") + + +class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): + """Dataclass representing a Deployment Config Metadata""" + + __slots__ = [ + "deployment_config_name", + "deployment_args", + "acceleration_configs", + "benchmark_metrics", + ] + + def __init__( + self, + config_name: Optional[str] = None, + metadata_config: Optional[JumpStartMetadataConfig] = None, + init_kwargs: Optional[JumpStartModelInitKwargs] = None, + deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None, + ): + """Instantiates DeploymentConfigMetadata object.""" + self.deployment_config_name = config_name + self.deployment_args = DeploymentArgs( + init_kwargs, deploy_kwargs, metadata_config.resolved_config + ) + self.benchmark_metrics = metadata_config.benchmark_metrics + self.acceleration_configs = metadata_config.acceleration_configs diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 
index 63cfac0939..15f9e9b52e 100644
--- a/src/sagemaker/jumpstart/utils.py
+++ b/src/sagemaker/jumpstart/utils.py
@@ -12,12 +12,18 @@
 # language governing permissions and limitations under the License.
 """This module contains utilities related to SageMaker JumpStart."""
 from __future__ import absolute_import
+
+from copy import copy
 import logging
 import os
+from functools import lru_cache, wraps
 from typing import Any, Dict, List, Set, Optional, Tuple, Union
 from urllib.parse import urlparse
 import boto3
-from packaging.version import Version
+from botocore.exceptions import ClientError
+from packaging.version import Version, InvalidVersion
+import botocore
+from sagemaker_core.shapes import ModelAccessConfig
 import sagemaker
 from sagemaker.config.config_schema import (
     MODEL_ENABLE_NETWORK_ISOLATION_PATH,
@@ -29,6 +35,7 @@
 from sagemaker.jumpstart import constants, enums
 from sagemaker.jumpstart import accessors
+from sagemaker.jumpstart.hub.parser_utils import camel_to_snake, snake_to_upper_camel
 from sagemaker.s3 import parse_s3_url
 from sagemaker.jumpstart.exceptions import (
     DeprecatedJumpStartModelError,
@@ -41,11 +48,19 @@
     JumpStartModelHeader,
     JumpStartModelSpecs,
     JumpStartVersionedModelId,
+    DeploymentConfigMetadata,
 )
 from sagemaker.session import Session
 from sagemaker.config import load_sagemaker_config
-from sagemaker.utils import resolve_value_from_config, TagsDict
+from sagemaker.utils import (
+    resolve_value_from_config,
+    TagsDict,
+    get_instance_rate_per_hour,
+    get_domain_for_region,
+    camel_case_to_pascal_case,
+)
 from sagemaker.workflow import is_pipeline_variable
+from sagemaker.user_agent import get_user_agent_extra_suffix


 def get_jumpstart_launched_regions_message() -> str:
@@ -123,7 +138,7 @@ def get_jumpstart_gated_content_bucket(
 def get_jumpstart_content_bucket(
     region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
 ) -> str:
-    """Returns regionalized content bucket name for JumpStart.
+    """Returns the regionalized content bucket name for JumpStart.

     Raises:
         ValueError: If JumpStart is not launched in ``region``.
@@ -164,6 +179,34 @@ def get_jumpstart_content_bucket(
     return bucket_to_return


+def get_neo_content_bucket(
+    region: str = constants.NEO_DEFAULT_REGION_NAME,
+) -> str:
+    """Returns the regionalized S3 bucket name for the Neo service.
+
+    Raises:
+        ValueError: If Neo is not launched in ``region``.
+ """ + + bucket_to_return: Optional[str] = None + if ( + constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE in os.environ + and len(os.environ[constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE]) > 0 + ): + bucket_to_return = os.environ[constants.ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE] + info_log = f"Using Neo bucket override: '{bucket_to_return}'" + constants.JUMPSTART_LOGGER.info(info_log) + else: + try: + bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[ + region + ].neo_content_bucket + except KeyError: + raise ValueError(f"Unable to get content bucket for Neo in {region} region.") + + return bucket_to_return + + def get_formatted_manifest( manifest: List[Dict], ) -> Dict[JumpStartVersionedModelId, JumpStartModelHeader]: @@ -318,6 +361,8 @@ def add_single_jumpstart_tag( tag_key_in_array(enums.JumpStartTag.MODEL_ID, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_VERSION, curr_tags) or tag_key_in_array(enums.JumpStartTag.MODEL_TYPE, curr_tags) + or tag_key_in_array(enums.JumpStartTag.INFERENCE_CONFIG_NAME, curr_tags) + or tag_key_in_array(enums.JumpStartTag.TRAINING_CONFIG_NAME, curr_tags) ) if is_uri else False @@ -348,11 +393,13 @@ def get_jumpstart_base_name_if_jumpstart_model( return None -def add_jumpstart_model_id_version_tags( +def add_jumpstart_model_info_tags( tags: Optional[List[TagsDict]], model_id: str, model_version: str, model_type: Optional[enums.JumpStartModelType] = None, + config_name: Optional[str] = None, + scope: enums.JumpStartScriptScope = None, ) -> List[TagsDict]: """Add custom model ID and version tags to JumpStart related resources.""" if model_id is None or model_version is None: @@ -376,6 +423,50 @@ def add_jumpstart_model_id_version_tags( tags, is_uri=False, ) + if config_name and scope == enums.JumpStartScriptScope.INFERENCE: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.INFERENCE_CONFIG_NAME, + tags, + is_uri=False, + ) + if config_name and scope == enums.JumpStartScriptScope.TRAINING: + tags = add_single_jumpstart_tag( + config_name, + enums.JumpStartTag.TRAINING_CONFIG_NAME, + tags, + is_uri=False, + ) + return tags + + +def add_hub_content_arn_tags( + tags: Optional[List[TagsDict]], + hub_content_arn: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = add_single_jumpstart_tag( + hub_content_arn, + enums.JumpStartTag.HUB_CONTENT_ARN, + tags, + is_uri=False, + ) + return tags + + +def add_bedrock_store_tags( + tags: Optional[List[TagsDict]], + compatibility: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = add_single_jumpstart_tag( + compatibility, + enums.JumpStartTag.BEDROCK, + tags, + is_uri=False, + ) return tags @@ -482,11 +573,18 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str: """Returns EULA message to display if one is available, else empty string.""" if model_specs.hosting_eula_key is None: return "" + return get_formatted_eula_message_template( + model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key + ) + + +def get_formatted_eula_message_template(model_id: str, region: str, hosting_eula_key: str) -> str: + """Returns a formatted EULA message.""" return ( - f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). " + f"Model '{model_id}' requires accepting end-user license agreement (EULA). " f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}." 
- f"amazonaws.com{'.cn' if region.startswith('cn-') else ''}" - f"/{model_specs.hosting_eula_key} for terms of use." + f"{get_domain_for_region(region)}" + f"/{hosting_eula_key} for terms of use." ) @@ -543,10 +641,12 @@ def verify_model_region_and_return_specs( version: Optional[str], scope: Optional[str], region: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -558,6 +658,8 @@ def verify_model_region_and_return_specs( scope (Optional[str]): scope of the JumpStart model to verify. region (Optional[str]): region of the JumpStart model to verify and obtains specs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -569,6 +671,7 @@ def verify_model_region_and_return_specs( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Raises: NotImplementedError: If the scope is not supported. @@ -597,9 +700,11 @@ def verify_model_region_and_return_specs( model_specs = accessors.JumpStartModelsAccessor.get_model_specs( # type: ignore region=region, model_id=model_id, + hub_arn=hub_arn, version=version, s3_client=sagemaker_session.s3_client, model_type=model_type, + sagemaker_session=sagemaker_session, ) if ( @@ -634,6 +739,9 @@ def verify_model_region_and_return_specs( scope=constants.JumpStartScriptScope.TRAINING, ) + if model_specs and config_name: + model_specs.set_config(config_name, scope) + return model_specs @@ -760,6 +868,7 @@ def validate_model_id_and_get_type( model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_arn: Optional[str] = None, ) -> Optional[enums.JumpStartModelType]: """Returns model type if the model ID is supported for the given script. 
@@ -760,6 +868,7 @@ def validate_model_id_and_get_type(
     model_version: Optional[str] = None,
     script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
     sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+    hub_arn: Optional[str] = None,
 ) -> Optional[enums.JumpStartModelType]:
     """Returns model type if the model ID is supported for the given script.
 
@@ -771,6 +880,17 @@ def validate_model_id_and_get_type(
         return None
     if not isinstance(model_id, str):
         return None
+    if hub_arn:
+        model_types = _validate_hub_service_model_id_and_get_type(
+            model_id=model_id,
+            hub_arn=hub_arn,
+            region=region,
+            model_version=model_version,
+            sagemaker_session=sagemaker_session,
+        )
+        return (
+            model_types[0] if model_types else None
+        )  # Currently this function only supports one model type
 
     s3_client = sagemaker_session.s3_client if sagemaker_session else None
     region = region or constants.JUMPSTART_DEFAULT_REGION_NAME
@@ -795,52 +915,111 @@ def validate_model_id_and_get_type(
     return None
 
 
-def get_jumpstart_model_id_version_from_resource_arn(
-    resource_arn: str,
-    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
-) -> Tuple[Optional[str], Optional[str]]:
-    """Returns the JumpStart model ID and version if in resource tags.
+def _validate_hub_service_model_id_and_get_type(
+    model_id: Optional[str],
+    hub_arn: str,
+    region: Optional[str] = None,
+    model_version: Optional[str] = None,
+    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+) -> List[enums.JumpStartModelType]:
+    """Returns a list of JumpStartModelType based off the HubContent.
 
-    Returns 'None' if model ID or version cannot be inferred from tags.
+    Only returns valid JumpStartModelType. Returns an empty array if none are found.
     """
+    hub_model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
+        region=region,
+        model_id=model_id,
+        version=model_version,
+        hub_arn=hub_arn,
+        sagemaker_session=sagemaker_session,
+    )
 
-    list_tags_result = sagemaker_session.list_tags(resource_arn)
+    hub_content_model_types = []
+    model_types_field: Optional[List[str]] = getattr(hub_model_specs, "model_types", [])
+    model_types = model_types_field if model_types_field else []
+    for model_type in model_types:
+        try:
+            hub_content_model_types.append(enums.JumpStartModelType[model_type])
+        except (KeyError, ValueError):
+            continue
 
-    model_id: Optional[str] = None
-    model_version: Optional[str] = None
+    return hub_content_model_types
 
-    model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS]
-    model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS]
-    for model_id_key in model_id_keys:
-        try:
-            model_id_from_tag = get_tag_value(model_id_key, list_tags_result)
-        except KeyError:
-            continue
-        if model_id_from_tag is not None:
-            if model_id is not None and model_id_from_tag != model_id:
-                constants.JUMPSTART_LOGGER.warning(
-                    "Found multiple model ID tags on the following resource: %s", resource_arn
-                )
-                model_id = None
-                break
-            model_id = model_id_from_tag
 
+def _extract_value_from_list_of_tags(
+    tag_keys: List[str],
+    list_tags_result: List[str],
+    resource_name: str,
+    resource_arn: str,
+):
+    """Extracts value from list of tags with check of duplicate tags.
+ """ + resolved_value = None + for tag_key in tag_keys: try: - model_version_from_tag = get_tag_value(model_version_key, list_tags_result) + value_from_tag = get_tag_value(tag_key, list_tags_result) except KeyError: continue - if model_version_from_tag is not None: - if model_version is not None and model_version_from_tag != model_version: + if value_from_tag is not None: + if resolved_value is not None and value_from_tag != resolved_value: constants.JUMPSTART_LOGGER.warning( - "Found multiple model version tags on the following resource: %s", resource_arn + "Found multiple %s tags on the following resource: %s", + resource_name, + resource_arn, ) - model_version = None + resolved_value = None break - model_version = model_version_from_tag + resolved_value = value_from_tag + return resolved_value + + +def get_jumpstart_model_info_from_resource_arn( + resource_arn: str, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """Returns the JumpStart model ID, version and config name if in resource tags. + + Returns 'None' if model ID or version or config name cannot be inferred from tags. + """ + + list_tags_result = sagemaker_session.list_tags(resource_arn) + + model_id_keys = [enums.JumpStartTag.MODEL_ID, *constants.EXTRA_MODEL_ID_TAGS] + model_version_keys = [enums.JumpStartTag.MODEL_VERSION, *constants.EXTRA_MODEL_VERSION_TAGS] + inference_config_name_keys = [enums.JumpStartTag.INFERENCE_CONFIG_NAME] + training_config_name_keys = [enums.JumpStartTag.TRAINING_CONFIG_NAME] + + model_id: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_id_keys, + list_tags_result=list_tags_result, + resource_name="model ID", + resource_arn=resource_arn, + ) + + model_version: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=model_version_keys, + list_tags_result=list_tags_result, + resource_name="model version", + resource_arn=resource_arn, + ) + + inference_config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=inference_config_name_keys, + list_tags_result=list_tags_result, + resource_name="inference config name", + resource_arn=resource_arn, + ) - return model_id, model_version + training_config_name: Optional[str] = _extract_value_from_list_of_tags( + tag_keys=training_config_name_keys, + list_tags_result=list_tags_result, + resource_name="training config name", + resource_arn=resource_arn, + ) + + return model_id, model_version, inference_config_name, training_config_name def get_region_fallback( @@ -890,7 +1069,11 @@ def get_config_names( scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> List[str]: - """Returns a list of config names for the given model ID and region.""" + """Returns a list of config names for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. 
+ """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -905,7 +1088,7 @@ def get_config_names( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") return list(metadata_configs.configs.keys()) if metadata_configs else [] @@ -915,15 +1098,21 @@ def get_benchmark_stats( model_id: str, model_version: str, config_names: Optional[List[str]] = None, + hub_arn: Optional[str] = None, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, ) -> Dict[str, List[JumpStartBenchmarkStat]]: - """Returns benchmark stats for the given model ID and region.""" + """Returns benchmark stats for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. + """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, version=model_version, + hub_arn=hub_arn, sagemaker_session=sagemaker_session, scope=scope, model_type=model_type, @@ -934,7 +1123,7 @@ def get_benchmark_stats( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") if not config_names: config_names = metadata_configs.configs.keys() if metadata_configs else [] @@ -942,7 +1131,7 @@ def get_benchmark_stats( benchmark_stats = {} for config_name in config_names: if config_name not in metadata_configs.configs: - raise ValueError(f"Unknown config name: '{config_name}'") + raise ValueError(f"Unknown config name: {config_name}") benchmark_stats[config_name] = metadata_configs.configs.get(config_name).benchmark_metrics return benchmark_stats @@ -956,8 +1145,13 @@ def get_jumpstart_configs( sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, -) -> Dict[str, List[JumpStartMetadataConfig]]: - """Returns metadata configs for the given model ID and region.""" + hub_arn: Optional[str] = None, +) -> Dict[str, JumpStartMetadataConfig]: + """Returns metadata configs for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. 
+ """ model_specs = verify_model_region_and_return_specs( region=region, model_id=model_id, @@ -965,6 +1159,7 @@ def get_jumpstart_configs( sagemaker_session=sagemaker_session, scope=scope, model_type=model_type, + hub_arn=hub_arn, ) if scope == enums.JumpStartScriptScope.INFERENCE: @@ -972,13 +1167,535 @@ def get_jumpstart_configs( elif scope == enums.JumpStartScriptScope.TRAINING: metadata_configs = model_specs.training_configs else: - raise ValueError(f"Unknown script scope {scope}.") + raise ValueError(f"Unknown script scope: {scope}.") if not config_names: - config_names = metadata_configs.configs.keys() if metadata_configs else [] + config_names = ( + metadata_configs.config_rankings.get("overall").rankings if metadata_configs else [] + ) + if hub_arn: + return ( + { + config_name: metadata_configs.configs[ + camel_to_snake(snake_to_upper_camel(config_name)) + ] + for config_name in config_names + } + if metadata_configs + else {} + ) return ( {config_name: metadata_configs.configs[config_name] for config_name in config_names} if metadata_configs else {} ) + + +def get_jumpstart_user_agent_extra_suffix( + model_id: Optional[str], + model_version: Optional[str], + config_name: Optional[str], + is_hub_content: Optional[bool], +) -> str: + """Returns the model-specific user agent string to be added to requests.""" + sagemaker_python_sdk_headers = get_user_agent_extra_suffix() + jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}" + config_specific_suffix = f"md/js_config#{config_name}" + hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}" + + if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None): + headers = sagemaker_python_sdk_headers + elif is_hub_content is True: + if model_id is None and model_version is None: + headers = f"{sagemaker_python_sdk_headers} {hub_specific_suffix}" + else: + headers = ( + f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix} {hub_specific_suffix}" + ) + else: + headers = f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}" + + if config_name: + headers = f"{headers} {config_specific_suffix}" + + return headers + + +def get_top_ranked_config_name( + region: str, + model_id: str, + model_version: str, + sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, + model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, + tolerate_deprecated_model: bool = False, + tolerate_vulnerable_model: bool = False, + hub_arn: Optional[str] = None, + ranking_name: enums.JumpStartConfigRankingName = enums.JumpStartConfigRankingName.DEFAULT, +) -> Optional[str]: + """Returns the top ranked config name for the given model ID and region. + + Raises: + ValueError: If the script scope is not supported by JumpStart. 
+ """ + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + scope=scope, + region=region, + hub_arn=hub_arn, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, + model_type=model_type, + ) + + if scope == enums.JumpStartScriptScope.INFERENCE: + return ( + model_specs.inference_configs.get_top_config_from_ranking( + ranking_name=ranking_name + ).config_name + if model_specs.inference_configs + else None + ) + if scope == enums.JumpStartScriptScope.TRAINING: + return ( + model_specs.training_configs.get_top_config_from_ranking( + ranking_name=ranking_name + ).config_name + if model_specs.training_configs + else None + ) + raise ValueError(f"Unsupported script scope: {scope}.") + + +def get_default_jumpstart_session_with_user_agent_suffix( + model_id: Optional[str] = None, + model_version: Optional[str] = None, + config_name: Optional[str] = None, + is_hub_content: Optional[bool] = False, +) -> Session: + """Returns default JumpStart SageMaker Session with model-specific user agent suffix.""" + botocore_session = botocore.session.get_session() + botocore_config = botocore.config.Config( + user_agent_extra=get_jumpstart_user_agent_extra_suffix( + model_id=model_id, + model_version=model_version, + config_name=config_name, + is_hub_content=is_hub_content, + ), + ) + botocore_session.set_default_client_config(botocore_config) + # shallow copy to not affect default session constant + session = copy(constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION) + session.boto_session = boto3.Session( + region_name=constants.JUMPSTART_DEFAULT_REGION_NAME, botocore_session=botocore_session + ) + session.sagemaker_client = boto3.client( + "sagemaker", region_name=constants.JUMPSTART_DEFAULT_REGION_NAME, config=botocore_config + ) + session.sagemaker_runtime_client = boto3.client( + "sagemaker-runtime", + region_name=constants.JUMPSTART_DEFAULT_REGION_NAME, + config=botocore_config, + ) + return session + + +def add_instance_rate_stats_to_benchmark_metrics( + region: str, + benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]], +) -> Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]: + """Adds instance types metric stats to the given benchmark_metrics dict. + + Args: + region (str): AWS region. + benchmark_metrics (Optional[Dict[str, List[JumpStartBenchmarkStat]]]): + Returns: + Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]: + Contains Error and metrics. 
+ """ + if not benchmark_metrics: + return None + + err_message = None + final_benchmark_metrics = {} + for instance_type, benchmark_metric_stats in benchmark_metrics.items(): + instance_type = instance_type if instance_type.startswith("ml.") else f"ml.{instance_type}" + + if not has_instance_rate_stat(benchmark_metric_stats) and not err_message: + try: + instance_type_rate = get_instance_rate_per_hour( + instance_type=instance_type, region=region + ) + + if not benchmark_metric_stats: + benchmark_metric_stats = [] + benchmark_metric_stats.append( + JumpStartBenchmarkStat({"concurrency": None, **instance_type_rate}) + ) + + final_benchmark_metrics[instance_type] = benchmark_metric_stats + except ClientError as e: + final_benchmark_metrics[instance_type] = benchmark_metric_stats + err_message = e.response["Error"] + except Exception: # pylint: disable=W0703 + final_benchmark_metrics[instance_type] = benchmark_metric_stats + else: + final_benchmark_metrics[instance_type] = benchmark_metric_stats + + return err_message, final_benchmark_metrics + + +def has_instance_rate_stat(benchmark_metric_stats: Optional[List[JumpStartBenchmarkStat]]) -> bool: + """Determines whether a benchmark metric stats contains instance rate metric stat. + + Args: + benchmark_metric_stats (Optional[List[JumpStartBenchmarkStat]]): + List of benchmark metric stats. + Returns: + bool: Whether the benchmark metric stats contains instance rate metric stat. + """ + if benchmark_metric_stats is None: + return True + for benchmark_metric_stat in benchmark_metric_stats: + if benchmark_metric_stat.name.lower() == "instance rate": + return True + return False + + +def get_metrics_from_deployment_configs( + deployment_configs: Optional[List[DeploymentConfigMetadata]], +) -> Dict[str, List[str]]: + """Extracts benchmark metrics from deployment configs metadata. + + Args: + deployment_configs (Optional[List[DeploymentConfigMetadata]]): + List of deployment configs metadata. + Returns: + Dict[str, List[str]]: Deployment configs bench metrics dict. 
+ """ + if not deployment_configs: + return {} + + data = {"Instance Type": [], "Config Name": [], "Concurrent Users": []} + instance_rate_data = {} + for index, deployment_config in enumerate(deployment_configs): + benchmark_metrics = deployment_config.benchmark_metrics + if not deployment_config.deployment_args or not benchmark_metrics: + continue + + for current_instance_type, current_instance_type_metrics in benchmark_metrics.items(): + instance_type_rate, concurrent_users = _normalize_benchmark_metrics( + current_instance_type_metrics + ) + + for concurrent_user, metrics in concurrent_users.items(): + instance_type_to_display = ( + f"{current_instance_type} (Default)" + if index == 0 + and concurrent_user + and int(concurrent_user) == 1 + and current_instance_type + == deployment_config.deployment_args.default_instance_type + else current_instance_type + ) + + data["Config Name"].append(deployment_config.deployment_config_name) + data["Instance Type"].append(instance_type_to_display) + data["Concurrent Users"].append(concurrent_user) + + if instance_type_rate: + instance_rate_column_name = ( + f"{instance_type_rate.name} ({instance_type_rate.unit})" + ) + instance_rate_data[instance_rate_column_name] = instance_rate_data.get( + instance_rate_column_name, [] + ) + instance_rate_data[instance_rate_column_name].append(instance_type_rate.value) + + for metric in metrics: + column_name = _normalize_benchmark_metric_column_name(metric.name, metric.unit) + data[column_name] = data.get(column_name, []) + data[column_name].append(metric.value) + + data = {**data, **instance_rate_data} + return data + + +def _normalize_benchmark_metric_column_name(name: str, unit: str) -> str: + """Normalizes benchmark metric column name. + + Args: + name (str): Name of the metric. + unit (str): Unit of the metric. + Returns: + str: Normalized metric column name. + """ + if "latency" in name.lower(): + name = f"Latency, TTFT (P50 in {unit.lower()})" + elif "throughput" in name.lower(): + name = f"Throughput (P50 in {unit.lower()}/user)" + return name + + +def _normalize_benchmark_metrics( + benchmark_metric_stats: List[JumpStartBenchmarkStat], +) -> Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]: + """Normalizes benchmark metrics dict. + + Args: + benchmark_metric_stats (List[JumpStartBenchmarkStat]): + List of benchmark metrics stats. + Returns: + Tuple[JumpStartBenchmarkStat, Dict[str, List[JumpStartBenchmarkStat]]]: + Normalized benchmark metrics dict. + """ + instance_type_rate = None + concurrent_users = {} + for current_instance_type_metric in benchmark_metric_stats: + if "instance rate" in current_instance_type_metric.name.lower(): + instance_type_rate = current_instance_type_metric + elif current_instance_type_metric.concurrency not in concurrent_users: + concurrent_users[current_instance_type_metric.concurrency] = [ + current_instance_type_metric + ] + else: + concurrent_users[current_instance_type_metric.concurrency].append( + current_instance_type_metric + ) + + return instance_type_rate, concurrent_users + + +def deployment_config_response_data( + deployment_configs: Optional[List[DeploymentConfigMetadata]], +) -> List[Dict[str, Any]]: + """Deployment config api response data. + + Args: + deployment_configs (Optional[List[DeploymentConfigMetadata]]): + List of deployment configs metadata. + Returns: + List[Dict[str, Any]]: List of deployment config api response data. 
+ """ + configs = [] + if not deployment_configs: + return configs + + for deployment_config in deployment_configs: + deployment_config_json = deployment_config.to_json() + benchmark_metrics = deployment_config_json.get("BenchmarkMetrics") + if benchmark_metrics and deployment_config.deployment_args: + deployment_config_json["BenchmarkMetrics"] = { + deployment_config.deployment_args.instance_type: benchmark_metrics.get( + deployment_config.deployment_args.instance_type + ) + } + + configs.append(deployment_config_json) + return configs + + +def _deployment_config_lru_cache(_func=None, *, maxsize: int = 128, typed: bool = False): + """LRU cache for deployment configs.""" + + def has_instance_rate_metric(config: DeploymentConfigMetadata) -> bool: + """Determines whether metadata config contains instance rate metric stat. + + Args: + config (DeploymentConfigMetadata): Metadata config metadata. + Returns: + bool: Whether the metadata config contains instance rate metric stat. + """ + if config.benchmark_metrics is None: + return True + for benchmark_metric_stats in config.benchmark_metrics.values(): + if not has_instance_rate_stat(benchmark_metric_stats): + return False + return True + + def wrapper_cache(f): + f = lru_cache(maxsize=maxsize, typed=typed)(f) + + @wraps(f) + def wrapped_f(*args, **kwargs): + res = f(*args, **kwargs) + + # Clear cache on first call if + # - The output does not contain Instant rate metrics + # as this is caused by missing policy. + if f.cache_info().hits == 0 and f.cache_info().misses == 1: + if isinstance(res, list): + for item in res: + if isinstance( + item, DeploymentConfigMetadata + ) and not has_instance_rate_metric(item): + f.cache_clear() + break + elif isinstance(res, dict): + keys = list(res.keys()) + if len(keys) == 0 or "Instance Rate" not in keys[-1]: + f.cache_clear() + elif len(res[keys[1]]) > len(res[keys[-1]]): + del res[keys[-1]] + f.cache_clear() + return res + + wrapped_f.cache_info = f.cache_info + wrapped_f.cache_clear = f.cache_clear + return wrapped_f + + if _func is None: + return wrapper_cache + return wrapper_cache(_func) + + +def _add_model_access_configs_to_model_data_sources( + model_data_sources: List[Dict[str, any]], + model_access_configs: Dict[str, ModelAccessConfig], + model_id: str, + region: str, +) -> List[Dict[str, any]]: + """Iterate over the accept EULA configs to ensure all channels are matched + + Args: + model_data_sources (DeploymentConfigMetadata): Model data sources that will be updated + model_access_configs (DeploymentConfigMetadata): Config holding accept_eula field + model_id (DeploymentConfigMetadata): Jumpstart model id. + region (str): Region where the user is operating in. + Returns: + List[Dict[str, Any]]: List of model data sources with accept EULA configs applied + Raise: + ValueError if at least one channel that requires EULA acceptance as not passed. 
+ """ + if not model_data_sources: + return model_data_sources + + acked_model_data_sources = [] + for model_data_source in model_data_sources: + hosting_eula_key = model_data_source.get("HostingEulaKey") + mutable_model_data_source = model_data_source.copy() + if hosting_eula_key: + if ( + not model_access_configs + or not model_access_configs.get(model_id) + or not model_access_configs.get(model_id).accept_eula + ): + eula_message_template = ( + "{model_source}{base_eula_message}{model_access_configs_message}" + ) + model_access_config_entry = ( + '"{model_id}":ModelAccessConfig(accept_eula=True)'.format(model_id=model_id) + ) + raise ValueError( + eula_message_template.format( + model_source="Additional " if model_data_source.get("ChannelName") else "", + base_eula_message=get_formatted_eula_message_template( + model_id=model_id, region=region, hosting_eula_key=hosting_eula_key + ), + model_access_configs_message=( + "Please add a ModelAccessConfig entry:" + f" {model_access_config_entry} " + "to model_access_configs to accept the EULA." + ), + ) + ) + mutable_model_data_source.pop( + "HostingEulaKey" + ) # pop when model access config is applied + mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = ( + camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump()) + ) + acked_model_data_sources.append(mutable_model_data_source) + else: + if "HostingEulaKey" in mutable_model_data_source: + mutable_model_data_source.pop( + "HostingEulaKey" + ) # pop when model access config is not applicable + acked_model_data_sources.append(mutable_model_data_source) + return acked_model_data_sources + + +def get_draft_model_content_bucket(provider: Dict, region: str) -> str: + """Returns the correct content bucket for a 1p draft model.""" + neo_bucket = get_neo_content_bucket(region=region) + if not provider: + return neo_bucket + provider_name = provider.get("name", "") + if provider_name == "JumpStart": + classification = provider.get("classification", "ungated") + if classification == "gated": + return get_jumpstart_gated_content_bucket(region=region) + return get_jumpstart_content_bucket(region=region) + return neo_bucket + + +def remove_env_var_from_estimator_kwargs_if_model_access_config_present( + init_kwargs: dict, model_access_config: Optional[dict] +): + """Remove env vars if ModelAccessConfig is used + + Args: + init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated. + accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). 
+ """ + if ( + model_access_config is not None + and init_kwargs.get("environment") is not None + and init_kwargs.get("model_uri") is not None + ): + if ( + constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY + in init_kwargs["environment"] + ): + del init_kwargs["environment"][ + constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY + ] + if "accept_eula" in init_kwargs["environment"]: + del init_kwargs["environment"]["accept_eula"] + + +def get_hub_access_config(hub_content_arn: Optional[str]): + """Get hub access config + + Args: + hub_content_arn (Optional[bool]): Arn of the model reference hub content + """ + if hub_content_arn is not None: + hub_access_config = {"HubContentArn": hub_content_arn} + else: + hub_access_config = None + + return hub_access_config + + +def get_model_access_config(accept_eula: Optional[bool], environment: Optional[dict]): + """Get access configs + + Args: + accept_eula (Optional[bool]): Whether or not the EULA was accepted, optionally passed in to Estimator.fit(). + """ + env_var_eula = environment.get("accept_eula") if environment else None + if env_var_eula is not None and accept_eula is not None: + raise ValueError( + "Cannot pass in both accept_eula and environment variables. " + "Please remove the environment variable and pass in the accept_eula parameter." + ) + + model_access_config = None + if env_var_eula is not None: + model_access_config = {"AcceptEula": env_var_eula == "true"} + if accept_eula is not None: + model_access_config = {"AcceptEula": accept_eula} + + return model_access_config + + +def get_latest_version(versions: List[str]) -> Optional[str]: + """Returns the latest version using sem-ver when possible.""" + try: + return None if not versions else max(versions, key=Version) + except InvalidVersion: + return max(versions) diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index c7098a1185..ea8041d1ee 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -167,10 +167,12 @@ def validate_hyperparameters( model_version: str, hyperparameters: Dict[str, Any], validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, + hub_arn: Optional[str] = None, region: Optional[str] = None, sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + config_name: Optional[str] = None, ) -> None: """Validate hyperparameters for JumpStart models. @@ -193,6 +195,7 @@ def validate_hyperparameters( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). 
diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py
index c7098a1185..ea8041d1ee 100644
--- a/src/sagemaker/jumpstart/validators.py
+++ b/src/sagemaker/jumpstart/validators.py
@@ -167,10 +167,12 @@ def validate_hyperparameters(
     model_version: str,
     hyperparameters: Dict[str, Any],
     validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
+    hub_arn: Optional[str] = None,
     region: Optional[str] = None,
     sagemaker_session: Optional[session.Session] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
+    config_name: Optional[str] = None,
 ) -> None:
     """Validate hyperparameters for JumpStart models.
 
@@ -193,6 +195,7 @@ def validate_hyperparameters(
         tolerate_deprecated_model (bool): True if deprecated models should be tolerated
             (exception not raised). False if these models should raise an exception.
             (Default: False).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
 
     Raises:
         JumpStartHyperparametersError: If the hyperparameters are not formatted correctly,
@@ -213,11 +216,13 @@ def validate_hyperparameters(
     model_specs = verify_model_region_and_return_specs(
         model_id=model_id,
         version=model_version,
+        hub_arn=hub_arn,
         region=region,
         scope=JumpStartScriptScope.TRAINING,
         sagemaker_session=sagemaker_session,
         tolerate_deprecated_model=tolerate_deprecated_model,
         tolerate_vulnerable_model=tolerate_vulnerable_model,
+        config_name=config_name,
     )
 
     hyperparameters_specs = model_specs.hyperparameters
diff --git a/src/sagemaker/lineage/_api_types.py b/src/sagemaker/lineage/_api_types.py
index eb73a1bb39..20baedf383 100644
--- a/src/sagemaker/lineage/_api_types.py
+++ b/src/sagemaker/lineage/_api_types.py
@@ -207,7 +207,7 @@ def __init__(self, user_profile_arn=None, user_profile_name=None, domain_id=None
             user_profile_arn=user_profile_arn,
             user_profile_name=user_profile_name,
             domain_id=domain_id,
-            **kwargs
+            **kwargs,
         )
 
 
diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py
index 2ce37f68bd..0cf6c6d55a 100644
--- a/src/sagemaker/local/entities.py
+++ b/src/sagemaker/local/entities.py
@@ -213,6 +213,10 @@ def start(self, input_data_config, output_data_config, hyperparameters, environm
             hyperparameters (dict): The HyperParameters for the training job.
             environment (dict): The collection of environment variables passed to the job.
             job_name (str): Name of the local training job being run.
+
+        Raises:
+            ValueError: If the input data configuration is not valid.
+            RuntimeError: If the data distribution type is not supported.
         """
         for channel in input_data_config:
             if channel["DataSource"] and "S3DataSource" in channel["DataSource"]:
@@ -233,10 +237,12 @@ def start(self, input_data_config, output_data_config, hyperparameters, environm
             # use a single Data URI - this makes handling S3 and File Data easier down the stack
             channel["DataUri"] = data_uri
 
-        if data_distribution and data_distribution != "FullyReplicated":
+        supported_distributions = ["FullyReplicated"]
+        if data_distribution and data_distribution not in supported_distributions:
             raise RuntimeError(
-                "DataDistribution: %s is not currently supported in Local Mode"
-                % data_distribution
+                "Invalid DataDistribution: '{}'. Local mode currently supports: {}.".format(
+                    data_distribution, ", ".join(supported_distributions)
+                )
             )
 
         self.start_time = datetime.datetime.now()
@@ -839,10 +845,10 @@ def _initialize_and_validate_parameters(self, overridden_parameters):
                 )
                 raise ClientError(error_msg, "start_pipeline_execution")
             parameter_type = default_parameters[param_name].parameter_type
-            if type(param_value) != parameter_type.python_type:  # pylint: disable=C0123
+            if not isinstance(param_value, parameter_type.python_type):
                 error_msg = self._construct_validation_exception_message(
-                    "Unexpected type for parameter '{}'. Expected {} but found "
-                    "{}.".format(param_name, parameter_type.python_type, type(param_value))
+                    f"Unexpected type for parameter '{param_name}'. Expected "
+                    f"{parameter_type.python_type} but found {type(param_value)}."
                 )
                 raise ClientError(error_msg, "start_pipeline_execution")
             if param_value == "":
diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py
index 32437a59c3..3d0f8394ab 100644
--- a/src/sagemaker/local/image.py
+++ b/src/sagemaker/local/image.py
@@ -30,7 +30,6 @@
 import tarfile
 import tempfile
 
-from distutils.spawn import find_executable
 from threading import Thread
 from typing import Dict, List
 from six.moves.urllib.parse import urlparse
@@ -170,7 +169,7 @@ def _get_compose_cmd_prefix():
         compose_cmd_prefix.extend(["docker", "compose"])
         return compose_cmd_prefix
 
-    if find_executable("docker-compose") is not None:
+    if shutil.which("docker-compose") is not None:
         logger.info("'Docker Compose' found using Docker Compose CLI.")
         compose_cmd_prefix.extend(["docker-compose"])
         return compose_cmd_prefix
@@ -474,7 +473,12 @@ def write_processing_config_files(
         """
         config_path = os.path.join(self.container_root, host, "config")
 
-        resource_config = {"current_host": host, "hosts": self.hosts}
+        resource_config = {
+            "current_host": host,
+            "hosts": self.hosts,
+            "network_interface_name": "eth0",
+            "current_instance_type": self.instance_type,
+        }
         _write_json_file(os.path.join(config_path, "resourceconfig.json"), resource_config)
 
         processing_job_config = {
@@ -520,7 +524,12 @@ def write_config_files(self, host, hyperparameters, input_data_config):
         """
         config_path = os.path.join(self.container_root, host, "input", "config")
 
-        resource_config = {"current_host": host, "hosts": self.hosts}
+        resource_config = {
+            "current_host": host,
+            "hosts": self.hosts,
+            "network_interface_name": "eth0",
+            "current_instance_type": self.instance_type,
+        }
 
         json_input_data_config = {}
         for c in input_data_config:
diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py
index 36a848aa52..89a2df2135 100644
--- a/src/sagemaker/local/local_session.py
+++ b/src/sagemaker/local/local_session.py
@@ -42,6 +42,8 @@
     _LocalPipeline,
 )
 from sagemaker.session import Session
+from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
+from sagemaker.telemetry.constants import Feature
 from sagemaker.utils import (
     get_config_value,
     _module_import_error,
@@ -83,6 +85,7 @@ def __init__(self, sagemaker_session=None):
         """
         self.sagemaker_session = sagemaker_session or LocalSession()
 
+    @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_processing_job")
     def create_processing_job(
         self,
         ProcessingJobName,
@@ -165,6 +168,7 @@ def describe_processing_job(self, ProcessingJobName):
             raise ClientError(error_response, "describe_processing_job")
         return LocalSagemakerClient._processing_jobs[ProcessingJobName].describe()
 
+    @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_training_job")
     def create_training_job(
         self,
         TrainingJobName,
@@ -235,6 +239,7 @@ def describe_training_job(self, TrainingJobName):
             raise ClientError(error_response, "describe_training_job")
         return LocalSagemakerClient._training_jobs[TrainingJobName].describe()
 
+    @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_transform_job")
     def create_transform_job(
         self,
         TransformJobName,
@@ -280,6 +285,7 @@ def describe_transform_job(self, TransformJobName):
             raise ClientError(error_response, "describe_transform_job")
         return LocalSagemakerClient._transform_jobs[TransformJobName].describe()
 
+    @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_model")
     def create_model(
         self, ModelName, PrimaryContainer, *args, **kwargs
     ):  # pylint: disable=unused-argument
@@ -329,6 +335,7 @@ def describe_endpoint_config(self, EndpointConfigName):
             raise ClientError(error_response, "describe_endpoint_config")
         return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe()
 
+    @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_endpoint_config")
     def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None):
         """Create the endpoint configuration.
 
@@ -360,6 +367,7 @@ def describe_endpoint(self, EndpointName):
             raise ClientError(error_response, "describe_endpoint")
         return LocalSagemakerClient._endpoints[EndpointName].describe()
 
+    @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_endpoint")
     def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None):
         """Create the endpoint.
 
@@ -428,6 +436,7 @@ def delete_model(self, ModelName):
         if ModelName in LocalSagemakerClient._models:
             del LocalSagemakerClient._models[ModelName]
 
+    @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_pipeline")
     def create_pipeline(
         self, pipeline, pipeline_description, **kwargs  # pylint: disable=unused-argument
     ):
diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py
index 950d0974db..3c7c3cda61 100644
--- a/src/sagemaker/local/utils.py
+++ b/src/sagemaker/local/utils.py
@@ -21,7 +21,6 @@
 import re
 import errno
 
-from distutils.dir_util import copy_tree
 from six.moves.urllib.parse import urlparse
 
 from sagemaker import s3
@@ -102,7 +101,7 @@ def move_to_destination(source, destination, job_name, sagemaker_session, prefix
 
 
 def recursive_copy(source, destination):
-    """A wrapper around distutils.dir_util.copy_tree.
+    """A wrapper around shutil.copytree.
 
     This won't throw any exception when the source directory does not exist.
 
@@ -111,7 +110,7 @@ def recursive_copy(source, destination):
         destination (str): destination path
     """
     if os.path.isdir(source):
-        copy_tree(source, destination)
+        shutil.copytree(source, destination, dirs_exist_ok=True)
 
 
 def kill_child_processes(pid):
@@ -154,7 +153,8 @@ def get_child_process_ids(pid):
 def get_docker_host():
     """Discover remote docker host address (if applicable) or use "localhost"
 
-    Use "docker context inspect" to read current docker host endpoint url,
+    When rootless Docker is enabled (Cgroup Driver: none), use the fixed SageMaker IP.
+    Otherwise, use "docker context inspect" to read the current docker host endpoint url;
     url must start with "tcp://"
 
     Args:
 
     Returns:
         docker_host (str): Docker host DNS or IP address
     """
+    # Check if using SageMaker rootless Docker by examining the cgroup driver
+    try:
+        cmd = ["docker", "info"]
+        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        output, err = process.communicate()
+        if process.returncode == 0:  # Check return code instead of stderr
+            output_text = output.decode("utf-8")
+            # Check for rootless Docker by looking at Cgroup Driver
+            if "Cgroup Driver: none" in output_text:
+                # SageMaker rootless Docker detected - return fixed IP
+                logger.warning("RootlessDocker detected (Cgroup Driver: none), returning fixed IP.")
+                return "172.17.0.1"
+            else:
+                logger.warning(
+                    "RootlessDocker not detected, falling back to remote host IP or localhost."
+                )
+    except subprocess.SubprocessError as e:
+        logger.warning("Failed to run 'docker info' command when checking rootlessDocker: %s.", e)
+
+    # Fall back to the existing logic for remote Docker hosts
     cmd = "docker context inspect".split()
     process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     output, err = process.communicate()
diff --git a/src/sagemaker/logs.py b/src/sagemaker/logs.py
index 7e80d796e7..56c532021d 100644
--- a/src/sagemaker/logs.py
+++ b/src/sagemaker/logs.py
@@ -166,7 +166,7 @@ def log_stream(client, log_group, stream_name, start_time=0, skip=0):
             logStreamName=stream_name,
             startTime=start_time,
             startFromHead=True,
-            **token_arg
+            **token_arg,
         )
         next_token = response["nextForwardToken"]
         events = response["events"]
diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py
index 71dd26db45..8b7d80b48d 100644
--- a/src/sagemaker/metric_definitions.py
+++ b/src/sagemaker/metric_definitions.py
@@ -20,6 +20,7 @@
 from sagemaker.jumpstart import utils as jumpstart_utils
 from sagemaker.jumpstart import artifacts
 from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
+from sagemaker.jumpstart.enums import JumpStartModelType
 from sagemaker.session import Session
 
 logger = logging.getLogger(__name__)
@@ -29,10 +30,13 @@ def retrieve_default(
     region: Optional[str] = None,
     model_id: Optional[str] = None,
     model_version: Optional[str] = None,
+    hub_arn: Optional[str] = None,
     instance_type: Optional[str] = None,
     tolerate_vulnerable_model: bool = False,
     tolerate_deprecated_model: bool = False,
     sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+    config_name: Optional[str] = None,
+    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
 ) -> Optional[List[Dict[str, str]]]:
     """Retrieves the default training metric definitions for the model matching the given arguments.
 
@@ -43,6 +47,8 @@ def retrieve_default(
             retrieve the default training metric definitions. (Default: None).
         model_version (str): The version of the model for which to retrieve the
             default training metric definitions. (Default: None).
+        hub_arn (str): The arn of the SageMaker Hub from which to retrieve
+            model details. (default: None).
         instance_type (str): An instance type to optionally supply in order to get
             metric definitions specific for the instance type.
         tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -56,6 +62,9 @@ def retrieve_default(
             object, used for SageMaker interactions. If not specified, one is created using
             the default AWS configuration chain. (Default:
             sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
+        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
+        model_type (JumpStartModelType): The type of the model, can be open weights model
+            or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
 
     Returns:
         list: The default metric definitions to use for the model or None.
@@ -71,9 +80,12 @@ def retrieve_default(
     return artifacts._retrieve_default_training_metric_definitions(
         model_id=model_id,
         model_version=model_version,
+        hub_arn=hub_arn,
         instance_type=instance_type,
         region=region,
         tolerate_vulnerable_model=tolerate_vulnerable_model,
         tolerate_deprecated_model=tolerate_deprecated_model,
         sagemaker_session=sagemaker_session,
+        config_name=config_name,
+        model_type=model_type,
     )
diff --git a/src/sagemaker/mlflow/__init__.py b/src/sagemaker/mlflow/__init__.py
new file mode 100644
index 0000000000..6549052177
--- /dev/null
+++ b/src/sagemaker/mlflow/__init__.py
@@ -0,0 +1,12 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
diff --git a/src/sagemaker/mlflow/forward_sagemaker_metrics.py b/src/sagemaker/mlflow/forward_sagemaker_metrics.py
new file mode 100644
index 0000000000..48b217482c
--- /dev/null
+++ b/src/sagemaker/mlflow/forward_sagemaker_metrics.py
@@ -0,0 +1,315 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+"""This module contains code related to forwarding SageMaker TrainingJob Metrics to MLflow."""
+
+from __future__ import absolute_import
+
+import os
+import platform
+import re
+from typing import Set, Tuple, List, Dict, Generator
+import boto3
+
+try:
+    import mlflow
+except ImportError:
+    raise ValueError("Unable to import mlflow, check if sagemaker-mlflow is installed.")
+from mlflow import MlflowClient
+from mlflow.entities import Metric, Param, RunTag
+
+from packaging import version
+ """ + + def encode_char(match): + return f"_{ord(match.group(0)):02x}_" + + # Check if we're on Mac/Unix and using MLflow 2.16.0 or greater + is_unix = platform.system() != "Windows" + mlflow_version = version.parse(mlflow.__version__) + allow_colon = is_unix and mlflow_version >= version.parse("2.16.0") + + if allow_colon: + pattern = r"[^\w\-./:\s]" + else: + pattern = r"[^\w\-./\s]" + + encoded = re.sub(pattern, encode_char, name) + base_name = encoded[:240] # Leave room for potential suffix to accommodate duplicates + + if base_name in existing_names: + suffix = 1 + # Edge case where even with suffix space there is a collision + # we will override one of the keys. + while f"{base_name}_{suffix}" in existing_names: + suffix += 1 + encoded = f"{base_name}_{suffix}" + + # Max length is 250 for mlflow metric/params + encoded = encoded[:250] + + existing_names.add(encoded) + return encoded + + +def decode(encoded_metric_name: str) -> str: + """Decodes an encoded metric name by replacing hexadecimal representations with ASCII + + This function reverses the encoding process by converting hexadecimal codes + back to their original characters. It looks for patterns of the form "_XX_" + where XX is a two-digit hexadecimal code, and replaces them with the + corresponding ASCII character. + + Args: + encoded_metric_name (str): The encoded metric name to be decoded. + + Returns: + str: The decoded metric name with hexadecimal codes replaced by their + corresponding characters. + + Example: + >>> decode("loss_3a_val") + "loss:val" + """ + + def replace_code(match): + code = match.group(1) + return chr(int(code, 16)) + + # Replace encoded characters + decoded = re.sub(r"_([0-9a-f]{2})_", replace_code, encoded_metric_name) + + return decoded + + +def get_training_job_details(job_arn: str) -> dict: + """Retrieve details of a SageMaker training job. + + Args: + job_arn (str): The ARN of the SageMaker training job. + + Returns: + dict: A dictionary containing the details of the training job. + + Raises: + boto3.exceptions.Boto3Error: If there's an issue with the AWS API call. + """ + sagemaker_client = boto3.client("sagemaker") + job_name = job_arn.split("/")[-1] + return sagemaker_client.describe_training_job(TrainingJobName=job_name) + + +def create_metric_queries(job_arn: str, metric_definitions: list) -> list: + """Create metric queries for SageMaker metrics. + + Args: + job_arn (str): The ARN of the SageMaker training job. + metric_definitions (list): List of metric definitions from the training job. + + Returns: + list: A list of dictionaries, each representing a metric query. + """ + metric_queries = [] + for metric in metric_definitions: + query = { + "MetricName": metric["Name"], + "XAxisType": "Timestamp", + "MetricStat": "Avg", + "Period": "OneMinute", + "ResourceArn": job_arn, + } + metric_queries.append(query) + return metric_queries + + +def get_metric_data(metric_queries: list) -> dict: + """Retrieve metric data from SageMaker. + + Args: + metric_queries (list): A list of metric queries. + + Returns: + dict: A dictionary containing the metric data results. + + Raises: + boto3.exceptions.Boto3Error: If there's an issue with the AWS API call. 
+ """ + sagemaker_metrics_client = boto3.client("sagemaker-metrics") + metric_data = sagemaker_metrics_client.batch_get_metrics(MetricQueries=metric_queries) + return metric_data + + +def prepare_mlflow_metrics( + metric_queries: list, metric_results: list +) -> Tuple[List[Metric], Dict[str, str]]: + """Prepare metrics for MLflow logging, encoding metric names if necessary. + + Args: + metric_queries (list): The original metric queries sent to SageMaker. + metric_results (list): The metric results from SageMaker batch_get_metrics. + + Returns: + Tuple[List[Metric], Dict[str, str]]: + - A list of Metric objects with encoded names (if necessary) + - A mapping of encoded to original names for metrics (only for encoded metrics) + """ + mlflow_metrics = [] + metric_name_mapping = {} + existing_names = set() + + for query, result in zip(metric_queries, metric_results): + if result["Status"] == "Complete": + metric_name = query["MetricName"] + encoded_name = encode(metric_name, existing_names) + metric_name_mapping[encoded_name] = metric_name + + for step, (timestamp, value) in enumerate( + zip(result["XAxisValues"], result["MetricValues"]) + ): + metric = Metric(key=encoded_name, value=value, timestamp=timestamp, step=step) + mlflow_metrics.append(metric) + + return mlflow_metrics, metric_name_mapping + + +def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param], Dict[str, str]]: + """Prepare hyperparameters for MLflow logging, encoding parameter names if necessary. + + Args: + hyperparameters (Dict[str, str]): The hyperparameters from the SageMaker job. + + Returns: + Tuple[List[Param], Dict[str, str]]: + - A list of Param objects with encoded names (if necessary) + - A mapping of encoded to original names for + hyperparameters (only for encoded parameters) + """ + mlflow_params = [] + param_name_mapping = {} + existing_names = set() + + for key, value in hyperparameters.items(): + encoded_key = encode(key, existing_names) + param_name_mapping[encoded_key] = key + mlflow_params.append(Param(encoded_key, str(value))) + + return mlflow_params, param_name_mapping + + +def batch_items(items: list, batch_size: int) -> Generator: + """Yield successive batch_size chunks from items. + + Args: + items (list): The list of items to be batched. + batch_size (int): The size of each batch. + + Yields: + list: A batch of items. + """ + for i in range(0, len(items), batch_size): + yield items[i : i + batch_size] + + +def log_to_mlflow(metrics: list, params: list, tags: dict) -> None: + """Log metrics, parameters, and tags to MLflow. + + Args: + metrics (list): List of metrics to log. + params (list): List of parameters to log. + tags (dict): Dictionary of tags to set. + + Raises: + mlflow.exceptions.MlflowException: If there's an issue with MLflow logging. + """ + client = MlflowClient() + + experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME") + if experiment_name is None or experiment_name.strip() == "": + experiment_name = "Default" + print("MLFLOW_EXPERIMENT_NAME not set. 
Using Default") + + experiment = client.get_experiment_by_name(experiment_name) + if experiment is None: + experiment_id = client.create_experiment(experiment_name) + else: + experiment_id = experiment.experiment_id + + run = client.create_run(experiment_id) + + for metric_batch in batch_items(metrics, 1000): + client.log_batch( + run.info.run_id, + metrics=metric_batch, + ) + for param_batch in batch_items(params, 1000): + client.log_batch(run.info.run_id, params=param_batch) + + tag_items = list(tags.items()) + for tag_batch in batch_items(tag_items, 1000): + tag_objects = [RunTag(key, str(value)) for key, value in tag_batch] + client.log_batch(run.info.run_id, tags=tag_objects) + client.set_terminated(run.info.run_id) + + +def log_sagemaker_job_to_mlflow(training_job_arn: str) -> None: + """Retrieve SageMaker metrics and hyperparameters and log them to MLflow. + + Args: + training_job_arn (str): The ARN of the SageMaker training job. + + Raises: + Exception: If there's any error during the process. + """ + # Get training job details + mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI")) + job_details = get_training_job_details(training_job_arn) + + # Extract hyperparameters and metric definitions + hyperparameters = job_details["HyperParameters"] + metric_definitions = job_details["AlgorithmSpecification"]["MetricDefinitions"] + + # Create and get metric queries + metric_queries = create_metric_queries(job_details["TrainingJobArn"], metric_definitions) + metric_data = get_metric_data(metric_queries) + + # Create a mapping of encoded to original metric names + # Prepare data for MLflow + mlflow_metrics, metric_name_mapping = prepare_mlflow_metrics( + metric_queries, metric_data["MetricQueryResults"] + ) + + # Create a mapping of encoded to original hyperparameter names + # Prepare data for MLflow + mlflow_params, param_name_mapping = prepare_mlflow_params(hyperparameters) + + mlflow_tags = { + "training_job_arn": training_job_arn, + "metric_name_mapping": str(metric_name_mapping), + "param_name_mapping": str(param_name_mapping), + } + + # Log to MLflow + log_to_mlflow(mlflow_metrics, mlflow_params, mlflow_tags) + print(f"Logged {len(mlflow_metrics)} metric datapoints to MLflow") + print(f"Logged {len(mlflow_params)} hyperparameters to MLflow") diff --git a/src/sagemaker/mlflow/tracking_server.py b/src/sagemaker/mlflow/tracking_server.py new file mode 100644 index 0000000000..0baa0f457b --- /dev/null +++ b/src/sagemaker/mlflow/tracking_server.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ + +"""This module contains code related to the Mlflow Tracking Server.""" + +from __future__ import absolute_import +from typing import Optional, TYPE_CHECKING +from sagemaker.apiutils import _utils + +if TYPE_CHECKING: + from sagemaker import Session + + +def generate_mlflow_presigned_url( + name: str, + expires_in_seconds: Optional[int] = None, + session_expiration_duration_in_seconds: Optional[int] = None, + sagemaker_session: Optional["Session"] = None, +) -> str: + """Generate a presigned url to acess the Mlflow UI. + + Args: + name (str): Name of the Mlflow Tracking Server + expires_in_seconds (int): Expiration time of the first usage + of the presigned url in seconds. + session_expiration_duration_in_seconds (int): Session duration of the presigned url in + seconds after the first use. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + Returns: + (str): Authorized Url to acess the Mlflow UI. + """ + session = sagemaker_session or _utils.default_session() + api_response = session.create_presigned_mlflow_tracking_server_url( + name, expires_in_seconds, session_expiration_duration_in_seconds + ) + return api_response["AuthorizedUrl"] diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index fd21b6342e..3bfac0c8da 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -20,7 +20,7 @@ import os import re import copy -from typing import List, Dict, Optional, Union +from typing import Callable, List, Dict, Optional, Union, Any import sagemaker from sagemaker import ( @@ -44,10 +44,15 @@ ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH, load_sagemaker_config, ) +from sagemaker.jumpstart.enums import JumpStartModelType +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) +from sagemaker.model_card.helpers import _hash_content_str from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum from sagemaker.session import Session from sagemaker.model_metrics import ModelMetrics -from sagemaker.deprecations import removed_kwargs from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.explainer import ExplainerConfig from sagemaker.metadata_properties import MetadataProperties @@ -66,6 +71,9 @@ resolve_nested_dict_value_from_config, format_tags, Tags, + _resolve_routing_config, + _validate_new_tags, + remove_tag_with_key, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor @@ -81,6 +89,7 @@ get_add_model_package_inference_args, get_update_model_package_inference_args, ) +from sagemaker.model_life_cycle import ModelLifeCycle # Setting LOGGER for backward compatibility, in case users import it... 
logger = LOGGER = logging.getLogger("sagemaker") @@ -144,7 +153,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, @@ -159,6 +168,8 @@ dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, resources: Optional[ResourceRequirements] = None, + additional_model_data_sources: Optional[Dict[str, Any]] = None, + model_reference_arn: Optional[str] = None, ): """Initialize a SageMaker ``Model``. @@ -174,7 +185,7 @@ It can be null if this is being used to create a Model to pass to a ``PipelineModel`` which has its own Role field. (default: None) - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. @@ -203,8 +214,8 @@ source_dir (str): The absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory is preserved - when training on Amazon SageMaker. If 'git_config' is provided, + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + is preserved when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. If the directory points to S3, no code is uploaded and the S3 location is used instead. @@ -322,9 +333,14 @@ for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). + additional_model_data_sources (Optional[Dict[str, Any]]): Additional location + of SageMaker model data (default: None). + model_reference_arn (Optional[str]): Hub Content Arn of a Model Reference type + content (default: None). """ self.model_data = model_data + self.additional_model_data_sources = additional_model_data_sources self.image_uri = image_uri self.predictor_cls = predictor_cls self.name = name @@ -353,7 +369,9 @@ sagemaker_config=self._sagemaker_config, ) self.endpoint_name = None + self.inference_component_name = None self._is_compiled_model = False + self._is_sharded_model = False self._compilation_job_name = None self._is_edge_packaged_model = False self.inference_recommender_job_results = None @@ -399,6 +417,34 @@ self.content_types = None self.response_types = None self.accept_eula = None + self.model_reference_arn = model_reference_arn + self._tags: Optional[Tags] = None + + def add_tags(self, tags: Tags) -> None: + """Add tags to this ``Model``. + + Args: + tags (Tags): Tags to add. + """ + self._tags = _validate_new_tags(tags, self._tags) + + def remove_tag_with_key(self, key: str) -> None: + """Remove a tag with the given key from the list of tags. + + Args: + key (str): The key of the tag to remove.
+ """ + self._tags = remove_tag_with_key(key, self._tags) + + @classmethod + def attach( + cls, + endpoint_name: str, + inference_component_name: Optional[str] = None, + sagemaker_session=None, + ) -> "Model": + """Attaches a Model object to an existing SageMaker Endpoint.""" + raise NotImplementedError @runnable_by_pipeline def register( @@ -427,6 +473,10 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, + accept_eula: Optional[bool] = None, + model_type: Optional[JumpStartModelType] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -478,6 +528,9 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments @@ -499,9 +552,11 @@ def register( model_package_group_name = utils.base_name_from_image( self.image_uri, default_base_name=ModelPackage.__name__ ) - - if model_package_group_name is not None: - container_def = self.prepare_container_def() + if ( + model_package_group_name is not None + and model_type is not JumpStartModelType.PROPRIETARY + ): + container_def = self.prepare_container_def(accept_eula=accept_eula) container_def = update_container_with_inference_params( framework=framework, framework_version=framework_version, @@ -544,6 +599,8 @@ def register( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + model_life_cycle=model_life_cycle, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args @@ -566,6 +623,7 @@ def create( serverless_inference_config: Optional[ServerlessInferenceConfig] = None, tags: Optional[Tags] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, ): """Create a SageMaker Model Entity @@ -607,6 +665,7 @@ def create( tags=format_tags(tags), serverless_inference_config=serverless_inference_config, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): @@ -628,6 +687,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): # pylint: disable=unused-argument """Return a dict created by ``sagemaker.container_def()``. 
@@ -670,6 +730,12 @@ def prepare_container_def( accept_eula=( accept_eula if accept_eula is not None else getattr(self, "accept_eula", None) ), + additional_model_data_sources=self.additional_model_data_sources, + model_reference_arn=( + model_reference_arn + if model_reference_arn is not None + else getattr(self, "model_reference_arn", None) + ), ) def is_repack(self) -> bool: @@ -678,6 +744,8 @@ def is_repack(self) -> bool: Returns: bool: if the source needs to be repacked or not """ + if self.source_dir is None or self.entry_point is None: + return False return self.source_dir and self.entry_point and not self.git_config def _upload_code(self, key_prefix: str, repack: bool = False) -> None: @@ -812,6 +880,7 @@ def _create_sagemaker_model( tags: Optional[Tags] = None, serverless_inference_config=None, accept_eula=None, + model_reference_arn: Optional[str] = None, ): """Create a SageMaker Model Entity @@ -836,6 +905,8 @@ The `accept_eula` value must be explicitly defined as `True` in order to accept the end-user license agreement (EULA) that some models require. (Default: None). + model_reference_arn (Optional[str]): Hub Content Arn of a Model Reference type + content (default: None). """ if self.model_package_arn is not None or self.algorithm_arn is not None: model_package = ModelPackage( @@ -867,6 +938,7 @@ accelerator_type=accelerator_type, serverless_inference_config=serverless_inference_config, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) if not isinstance(self.sagemaker_session, PipelineSession): @@ -1309,6 +1381,11 @@ def deploy( resources: Optional[ResourceRequirements] = None, endpoint_type: EndpointType = EndpointType.MODEL_BASED, managed_instance_scaling: Optional[str] = None, + inference_component_name=None, + routing_config: Optional[Dict[str, Any]] = None, + model_reference_arn: Optional[str] = None, + inference_ami_version: Optional[str] = None, + update_endpoint: Optional[bool] = False, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -1406,6 +1483,25 @@ Endpoint. (Default: None). endpoint_type (Optional[EndpointType]): The type of an endpoint used to deploy models. (Default: EndpointType.MODEL_BASED). + routing_config (Optional[Dict[str, Any]]): Settings that control how the endpoint routes incoming + traffic to the instances that the endpoint hosts. + Currently, the dictionary key ``RoutingStrategy`` is supported. + + .. code:: python + + { + "RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM + } + model_reference_arn (Optional[str]): Hub Content Arn of a Model Reference type + content (default: None). + inference_ami_version (Optional[str]): Specifies an option from a collection of preconfigured + Amazon Machine Image (AMI) images. For a full list of options, see: + https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html + update_endpoint (Optional[bool]): + Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint + and delete resources corresponding to the previous EndpointConfig.
Default: False + Note: Currently this is supported only for single model endpoints. Raises: ValueError: If the argument combination check fails in these circumstances: - If no role is specified or @@ -1415,14 +1511,12 @@ inference config or - If inference recommendation id is specified along with incompatible parameters Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Callable[[string, sagemaker.session.Session], Any] or None: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ self.accept_eula = accept_eula - removed_kwargs("update_endpoint", kwargs) - self._init_sagemaker_session_if_does_not_exist(instance_type) # Depending on the instance type, a local session (or) a session is initialized. self.role = resolve_value_from_config( @@ -1441,7 +1535,8 @@ sagemaker_session=self.sagemaker_session, ) - tags = format_tags(tags) + self.add_tags(tags) + tags = format_tags(self._tags) if ( getattr(self.sagemaker_session, "settings", None) is not None @@ -1458,6 +1553,8 @@ if self.role is None: raise ValueError("Role can not be null for deploying a model") + routing_config = _resolve_routing_config(routing_config) + if ( inference_recommendation_id is not None or self.inference_recommender_job_results is not None @@ -1512,8 +1609,32 @@ if self._base_name is not None: self._base_name = "-".join((self._base_name, compiled_model_suffix)) + if self._is_sharded_model: + if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED: + logging.warning( + "Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - " + "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints." + ) + endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED + + if self._enable_network_isolation: + raise ValueError( + "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " + "Loading of model requires network access." + ) + + if resources and resources.num_cpus and resources.num_cpus > 0: + logger.warning( + "NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker " + "Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`."
+ ) + # Support multiple models on same endpoint if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED: + if update_endpoint: + raise ValueError( + "Currently update_endpoint is supported for single model endpoints" + ) if endpoint_name: self.endpoint_name = endpoint_name else: @@ -1543,6 +1664,8 @@ def deploy( model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, managed_instance_scaling=managed_instance_scaling_config, + routing_config=routing_config, + inference_ami_version=inference_ami_version, ) self.sagemaker_session.endpoint_from_production_variants( @@ -1553,7 +1676,7 @@ def deploy( vpc_config=self.vpc_config, enable_network_isolation=self._enable_network_isolation, role=self.role, - live_logging=endpoint_logging, + live_logging=False, # TODO: enable when IC supports this wait=wait, ) @@ -1580,11 +1703,15 @@ def deploy( "ComputeResourceRequirements": resources.get_compute_resource_requirements(), } runtime_config = {"CopyCount": resources.copy_count} - inference_component_name = unique_name_from_base(self.name) + self.inference_component_name = ( + inference_component_name + or self.inference_component_name + or unique_name_from_base(self.name) + ) # [TODO]: Add endpoint_logging support self.sagemaker_session.create_inference_component( - inference_component_name=inference_component_name, + inference_component_name=self.inference_component_name, endpoint_name=self.endpoint_name, variant_name="AllTraffic", # default variant name specification=inference_component_spec, @@ -1597,7 +1724,7 @@ def deploy( predictor = self.predictor_cls( self.endpoint_name, self.sagemaker_session, - component_name=inference_component_name, + component_name=self.inference_component_name, ) if serializer: predictor.serializer = serializer @@ -1612,6 +1739,8 @@ def deploy( accelerator_type=accelerator_type, tags=tags, serverless_inference_config=serverless_inference_config, + accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) serverless_inference_config_dict = ( serverless_inference_config._to_request_dict() if is_serverless else None @@ -1625,6 +1754,8 @@ def deploy( volume_size=volume_size, model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, + routing_config=routing_config, + inference_ami_version=inference_ami_version, ) if endpoint_name: self.endpoint_name = endpoint_name @@ -1659,17 +1790,38 @@ def deploy( if is_explainer_enabled: explainer_config_dict = explainer_config._to_request_dict() - self.sagemaker_session.endpoint_from_production_variants( - name=self.endpoint_name, - production_variants=[production_variant], - tags=tags, - kms_key=kms_key, - wait=wait, - data_capture_config_dict=data_capture_config_dict, - explainer_config_dict=explainer_config_dict, - async_inference_config_dict=async_inference_config_dict, - live_logging=endpoint_logging, - ) + if update_endpoint: + endpoint_config_name = self.sagemaker_session.create_endpoint_config( + name=self.name, + model_name=self.name, + initial_instance_count=initial_instance_count, + instance_type=instance_type, + accelerator_type=accelerator_type, + tags=tags, + kms_key=kms_key, + data_capture_config_dict=data_capture_config_dict, + volume_size=volume_size, + model_data_download_timeout=model_data_download_timeout, + container_startup_health_check_timeout=container_startup_health_check_timeout, + explainer_config_dict=explainer_config_dict, + 
async_inference_config_dict=async_inference_config_dict, + serverless_inference_config_dict=serverless_inference_config_dict, + routing_config=routing_config, + inference_ami_version=inference_ami_version, + ) + self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name) + else: + self.sagemaker_session.endpoint_from_production_variants( + name=self.endpoint_name, + production_variants=[production_variant], + tags=tags, + kms_key=kms_key, + wait=wait, + data_capture_config_dict=data_capture_config_dict, + explainer_config_dict=explainer_config_dict, + async_inference_config_dict=async_inference_config_dict, + live_logging=endpoint_logging, + ) if self.predictor_cls: predictor = self.predictor_cls(self.endpoint_name, self.sagemaker_session) @@ -1841,7 +1993,7 @@ def __init__( role: Optional[str] = None, entry_point: Optional[str] = None, source_dir: Optional[str] = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, name: Optional[str] = None, container_log_level: Union[int, PipelineVariable] = logging.INFO, @@ -1878,11 +2030,11 @@ source_dir (str): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. If 'git_config' is provided, - 'source_dir' should be a relative location to a directory in the Git repo. - If the directory points to S3, no code will be uploaded and the S3 location - will be used instead. + point to a file with name ``sourcedir.tar.gz``. Structure within this + directory is preserved when training on Amazon SageMaker. If 'git_config' + is provided, 'source_dir' should be a relative location to a directory in the + Git repo. If the directory points to S3, no code will be uploaded and the S3 + location will be used instead. .. admonition:: Example >>> |----- test.py You can assign entry_point='inference.py', source_dir='src'. - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name.
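To make the new ``deploy`` parameters above concrete, here is a hedged sketch of redeploying to an existing single-model endpoint; the endpoint name and instance settings are hypothetical, and ``RoutingStrategy`` is the enum the docstring above references:

.. code:: python

    from sagemaker.enums import RoutingStrategy

    predictor = model.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.xlarge",
        endpoint_name="my-existing-endpoint",  # hypothetical existing endpoint
        # Deploys a new EndpointConfig to the endpoint and deletes the old one's resources.
        update_endpoint=True,
        routing_config={"RoutingStrategy": RoutingStrategy.RANDOM},
    )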
@@ -2021,6 +2173,8 @@ def is_repack(self) -> bool: Returns: bool: if the source needs to be repacked or not """ + if self.source_dir is None or self.entry_point is None: + return False return self.source_dir and self.entry_point and not (self.key_prefix or self.git_config) @@ -2294,6 +2448,23 @@ def update_source_uri( sagemaker_session = self.sagemaker_session or sagemaker.Session() sagemaker_session.sagemaker_client.update_model_package(**update_source_uri_args) + def update_model_life_cycle( + self, + model_life_cycle: ModelLifeCycle, + ): + """Sets the ModelLifeCycle for the model package. + + Args: + model_life_cycle (ModelLifeCycle): The current state of the model package in its life cycle + + """ + update_model_life_cycle_args = { + "ModelPackageArn": self.model_package_arn, + "ModelLifeCycle": model_life_cycle, + } + sagemaker_session = self.sagemaker_session or sagemaker.Session() + sagemaker_session.sagemaker_client.update_model_package(**update_model_life_cycle_args) + def remove_customer_metadata_properties( self, customer_metadata_properties_to_remove: List[str] ): @@ -2370,3 +2541,67 @@ ) sagemaker_session.sagemaker_client.update_model_package(**model_package_update_args) + + def update_model_card(self, model_card: Union[ModelCard, ModelPackageModelCard]): + """Updates the model card content that was created with the model package + + Args: + model_card (ModelCard | ModelPackageModelCard): Updated Model Card content + """ + + sagemaker_session = self.sagemaker_session or sagemaker.Session() + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=self.model_package_arn + ) + if hasattr(model_card, "model_package_details"): + model_card.model_package_details = None + update_model_card_req = model_card._create_request_args() + if update_model_card_req.get("ModelCardName") is not None: + del update_model_card_req["ModelCardName"] + if update_model_card_req["Content"] is not None: + if "model_package_details" in update_model_card_req["Content"]: + update_model_card_req["Content"].pop("model_package_details", None) + update_model_card_req["ModelCardContent"] = update_model_card_req["Content"] + del update_model_card_req["Content"] + + if "ModelCard" in desc_model_package: + if update_model_card_req["ModelCardStatus"] is not None: + if ( + desc_model_package["ModelCard"]["ModelCardStatus"] + != update_model_card_req["ModelCardStatus"] + ): + new_mc_mp_req = update_model_card_req + del new_mc_mp_req["ModelCardContent"] + update_model_package_args = { + "ModelPackageArn": self.model_package_arn, + "ModelCard": new_mc_mp_req, + } + sagemaker_session.sagemaker_client.update_model_package( + **update_model_package_args + ) + + if update_model_card_req.get("ModelCardContent") is not None: + previous_content_hash = _hash_content_str( + desc_model_package["ModelCard"]["ModelCardContent"] + ) + current_content_hash = _hash_content_str(update_model_card_req["ModelCardContent"]) + if not ( + previous_content_hash == current_content_hash + or update_model_card_req.get("ModelCardContent") == "{}" + or update_model_card_req.get("ModelCardContent") == "null" + ): + new_mc_mp_req = update_model_card_req + del new_mc_mp_req["ModelCardStatus"] + update_model_package_args = { + "ModelPackageArn": self.model_package_arn,
"ModelCard": update_model_card_req, + } + sagemaker_session.sagemaker_client.update_model_package(**update_model_package_args) diff --git a/src/sagemaker/model_card/__init__.py b/src/sagemaker/model_card/__init__.py index 679da42a3f..b7a7d24dc7 100644 --- a/src/sagemaker/model_card/__init__.py +++ b/src/sagemaker/model_card/__init__.py @@ -29,6 +29,7 @@ AdditionalInformation, ModelCard, ModelPackage, + ModelPackageModelCard, ) from sagemaker.model_card.schema_constraints import ( # noqa: F401 # pylint: disable=unused-import diff --git a/src/sagemaker/model_card/helpers.py b/src/sagemaker/model_card/helpers.py index a8d9e7940e..925d9ae0e0 100644 --- a/src/sagemaker/model_card/helpers.py +++ b/src/sagemaker/model_card/helpers.py @@ -503,12 +503,12 @@ def _read_s3_json(session: Session, bucket: str, key: str): raise result = {} - if data["ContentType"] == "application/json" or data["ContentType"] == "binary/octet-stream": + content_types = ["application/json", "binary/octet-stream", "application/octet-stream"] + if data["ContentType"] in content_types: result = json.loads(data["Body"].read().decode("utf-8")) else: logger.warning( - "Invalid file type %s. application/json or binary/octet-stream is expected.", - data["ContentType"], + "Invalid file type %s. %s is expected.", data["ContentType"], ", ".join(content_types) ) return result diff --git a/src/sagemaker/model_card/model_card.py b/src/sagemaker/model_card/model_card.py index 33af98723f..c13e979efc 100644 --- a/src/sagemaker/model_card/model_card.py +++ b/src/sagemaker/model_card/model_card.py @@ -16,7 +16,7 @@ import json import logging from datetime import datetime -from typing import Optional, Union, List, Any +from typing import Optional, Union, List, Any, Dict from botocore.exceptions import ClientError from boto3.session import Session as boto3_Session from six.moves.urllib.parse import urlparse @@ -1883,3 +1883,29 @@ def list_export_jobs( return sagemaker_session.sagemaker_client.list_model_card_export_jobs( ModelCardName=model_card_name, **kwargs ) + + +class ModelPackageModelCard(object): + """Use an Amazon SageMaker Model Card to document qualitative and quantitative information about a model.""" # noqa E501 # pylint: disable=c0301 + + def __init__( + self, + model_card_content: Optional[Dict[str, Any]] = None, + model_card_status: Optional[str] = None, + ): + + self.model_card_content = model_card_content + self.model_card_status = model_card_status + + def _create_request_args(self): + """Generate the request body for create model card call. + + Args: + model_card_content dict[str]: Content of the model card. + model_card_status (str): Status of the model card you want to export. + + """ # noqa E501 # pylint: disable=line-too-long + request_args = {} + request_args["ModelCardStatus"] = self.model_card_status + request_args["Content"] = json.dumps(self.model_card_content, cls=_JSONEncoder) + return request_args diff --git a/src/sagemaker/model_life_cycle.py b/src/sagemaker/model_life_cycle.py new file mode 100644 index 0000000000..59403e91c8 --- /dev/null +++ b/src/sagemaker/model_life_cycle.py @@ -0,0 +1,51 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This file contains code related to model life cycle.""" +from __future__ import absolute_import + +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + + +class ModelLifeCycle(object): + """Accepts ModelLifeCycle parameters for conversion to request dict.""" + + def __init__( + self, + stage: Optional[Union[str, PipelineVariable]] = None, + stage_status: Optional[Union[str, PipelineVariable]] = None, + stage_description: Optional[Union[str, PipelineVariable]] = None, + ): + """Initialize a ``ModelLifeCycle`` instance and turn parameters into dict. + + # TODO: flesh out docstrings + Args: + stage (str or PipelineVariable): + stage_status (str or PipelineVariable): + stage_description (str or PipelineVariable): + """ + self.stage = stage + self.stage_status = stage_status + self.stage_description = stage_description + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + model_life_cycle_request = dict() + if self.stage: + model_life_cycle_request["Stage"] = self.stage + if self.stage_status: + model_life_cycle_request["StageStatus"] = self.stage_status + if self.stage_description: + model_life_cycle_request["StageDescription"] = self.stage_description + return model_life_cycle_request diff --git a/src/sagemaker/model_monitor/clarify_model_monitoring.py b/src/sagemaker/model_monitor/clarify_model_monitoring.py index 3edfabc747..9dc915a2d7 100644 --- a/src/sagemaker/model_monitor/clarify_model_monitoring.py +++ b/src/sagemaker/model_monitor/clarify_model_monitoring.py @@ -86,11 +86,9 @@ def __init__( object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. """ - if type(self) == __class__: # pylint: disable=unidiomatic-typecheck + if self.__class__ is __class__: raise TypeError( - "{} is abstract, please instantiate its subclasses instead.".format( - __class__.__name__ - ) + f"{__class__.__name__} is abstract, please instantiate its subclasses instead." 
) session = sagemaker_session or Session() @@ -1105,6 +1103,8 @@ def create_monitoring_schedule( monitor_schedule_name=monitor_schedule_name, job_definition_name=new_job_definition_name, schedule_cron_expression=schedule_cron_expression, + data_analysis_start_time=data_analysis_start_time, + data_analysis_end_time=data_analysis_end_time, ) self.job_definition_name = new_job_definition_name self.monitoring_schedule_name = monitor_schedule_name diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 436377fea5..3bc29a1cf4 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -2413,7 +2413,12 @@ def _update_data_quality_monitoring_schedule( ) self.sagemaker_session.sagemaker_client.create_data_quality_job_definition(**request_dict) try: - self._update_monitoring_schedule(new_job_definition_name, schedule_cron_expression) + self._update_monitoring_schedule( + job_definition_name=new_job_definition_name, + schedule_cron_expression=schedule_cron_expression, + data_analysis_start_time=data_analysis_start_time, + data_analysis_end_time=data_analysis_end_time, + ) self.job_definition_name = new_job_definition_name if role is not None: self.role = role diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 937180bd44..6f788eb8b9 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session @@ -29,11 +30,14 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the model artifact Amazon S3 URI for the model matching the given arguments. @@ -43,6 +47,8 @@ the model artifact S3 URI. model_version (str): The version of the JumpStart model for which to retrieve the model artifact S3 URI. + hub_arn (str): The ARN of the SageMaker Hub from which to retrieve + model details. (default: None). model_scope (str): The model type. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -57,6 +63,10 @@ object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, which can be an open weights + model or a proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). + Returns: str: The model artifact S3 URI for the corresponding model.
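A small sketch of the extended ``retrieve`` call above; the model id and hub ARN are placeholders:

.. code:: python

    from sagemaker import model_uris
    from sagemaker.jumpstart.enums import JumpStartModelType

    # Retrieve the inference artifact S3 URI for a JumpStart model, optionally
    # scoped to a SageMaker Hub (placeholder ARN and model id).
    uri = model_uris.retrieve(
        model_id="huggingface-llm-example-model",  # placeholder model id
        model_version="*",
        model_scope="inference",
        hub_arn="arn:aws:sagemaker:us-west-2:123456789012:hub/my-hub",
        model_type=JumpStartModelType.OPEN_WEIGHTS,
    )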
@@ -75,10 +85,13 @@ def retrieve( return artifacts._retrieve_model_uri( model_id=model_id, model_version=model_version, # type: ignore + hub_arn=hub_arn, model_scope=model_scope, instance_type=instance_type, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) diff --git a/src/sagemaker/modules/__init__.py b/src/sagemaker/modules/__init__.py new file mode 100644 index 0000000000..d7f209f00c --- /dev/null +++ b/src/sagemaker/modules/__init__.py @@ -0,0 +1,19 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""SageMaker modules directory.""" +from __future__ import absolute_import + +from sagemaker_core.main.utils import logger as sagemaker_core_logger +from sagemaker_core.helper.session_helper import Session, get_execution_role # noqa: F401 + +logger = sagemaker_core_logger diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py new file mode 100644 index 0000000000..8fdf88e735 --- /dev/null +++ b/src/sagemaker/modules/configs.py @@ -0,0 +1,307 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides the configuration classes used in ``sagemaker.modules``. + +Some of these classes are re-exported from ``sagemaker_core.shapes``. For convenience, +users can import these classes directly from ``sagemaker.modules.configs``. + +For more documentation on ``sagemaker_core.shapes``, see: + - https://sagemaker-core.readthedocs.io/en/stable/#sagemaker-core-shapes +""" + +from __future__ import absolute_import + +from typing import Optional, Union, List +from pydantic import BaseModel, model_validator, ConfigDict + +import sagemaker_core.shapes as shapes + +# TODO: Can we add custom logic to some of these to set better defaults?
+from sagemaker_core.shapes import ( + StoppingCondition, + RetryStrategy, + Channel, + ShuffleConfig, + DataSource, + S3DataSource, + FileSystemDataSource, + TrainingImageConfig, + TrainingRepositoryAuthConfig, + Tag, + InfraCheckConfig, + RemoteDebugConfig, + SessionChainingConfig, + InstanceGroup, + MetricDefinition, +) + +from sagemaker.modules.utils import convert_unassigned_to_none + +__all__ = [ + "SourceCode", + "StoppingCondition", + "RetryStrategy", + "OutputDataConfig", + "Channel", + "ShuffleConfig", + "DataSource", + "S3DataSource", + "FileSystemDataSource", + "TrainingImageConfig", + "TrainingRepositoryAuthConfig", + "Tag", + "InfraCheckConfig", + "RemoteDebugConfig", + "SessionChainingConfig", + "InstanceGroup", + "TensorBoardOutputConfig", + "CheckpointConfig", + "Compute", + "Networking", + "InputData", + "MetricDefinition", +] + + +class BaseConfig(BaseModel): + """BaseConfig""" + + model_config = ConfigDict(validate_assignment=True, extra="forbid") + + +class SourceCode(BaseConfig): + """SourceCode. + + The SourceCode class allows the user to specify the source code location, dependencies, + entry script, or commands to be executed in the training job container. + + Parameters: + source_dir (Optional[str]): + The local directory, S3 URI, or path to a tar.gz file stored locally or in S3 that contains + the source code to be used in the training job container. + requirements (Optional[str]): + The path within ``source_dir`` to a ``requirements.txt`` file. If specified, the listed + requirements will be installed in the training job container. + entry_script (Optional[str]): + The path within ``source_dir`` to the entry script that will be executed in the training + job container. If not specified, command must be provided. + command (Optional[str]): + The command(s) to execute in the training job container. Example: "python my_script.py". + If not specified, entry_script must be provided. + ignore_patterns (Optional[List[str]]): + The ignore patterns to ignore specific files/folders when uploading to S3. If not specified, + default to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints']. + """ + + source_dir: Optional[str] = None + requirements: Optional[str] = None + entry_script: Optional[str] = None + command: Optional[str] = None + ignore_patterns: Optional[List[str]] = [ + ".env", + ".git", + "__pycache__", + ".DS_Store", + ".cache", + ".ipynb_checkpoints", + ] + + +class Compute(shapes.ResourceConfig): + """Compute. + + The Compute class is a subclass of ``sagemaker_core.shapes.ResourceConfig`` + and allows the user to specify the compute resources for the training job. + + Parameters: + instance_type (Optional[str]): + The ML compute instance type. For information about available instance types, + see https://aws.amazon.com/sagemaker/pricing/. + instance_count (Optional[int]): The number of ML compute instances to use. For distributed + training, provide a value greater than 1. + volume_size_in_gb (Optional[int]): + The size of the ML storage volume that you want to provision. ML storage volumes store + model artifacts and incremental states. Training algorithms might also use the ML + storage volume for scratch space. Default: 30 + volume_kms_key_id (Optional[str]): + The Amazon Web Services KMS key that SageMaker uses to encrypt data on the storage + volume attached to the ML compute instance(s) that run the training job.
+ keep_alive_period_in_seconds (Optional[int]): + The duration of time in seconds to retain configured resources in a warm pool for + subsequent training jobs. + instance_groups (Optional[List[InstanceGroup]]): + A list of instance groups for heterogeneous clusters to be used in the training job. + training_plan_arn (Optional[str]): + The Amazon Resource Name (ARN) of the training plan to use for this resource configuration. + enable_managed_spot_training (Optional[bool]): + To train models using managed spot training, choose True. Managed spot training + provides a fully managed and scalable infrastructure for training machine learning + models. This option is useful when training jobs can be interrupted and when there + is flexibility in when the training job is run. + """ + + volume_size_in_gb: Optional[int] = 30 + enable_managed_spot_training: Optional[bool] = None + + @model_validator(mode="after") + def _model_validator(self) -> "Compute": + """Convert Unassigned values to None.""" + return convert_unassigned_to_none(self) + + def _to_resource_config(self) -> shapes.ResourceConfig: + """Convert to a sagemaker_core.shapes.ResourceConfig object.""" + compute_config_dict = self.model_dump() + resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys()) + filtered_dict = { + k: v + for k, v in compute_config_dict.items() + if k in resource_config_fields and v is not None + } + if not filtered_dict: + return None + return shapes.ResourceConfig(**filtered_dict) + + +class Networking(shapes.VpcConfig): + """Networking. + + The Networking class is a subclass of ``sagemaker_core.shapes.VpcConfig`` and + allows the user to specify the networking configuration for the training job. + + Parameters: + security_group_ids (Optional[List[str]]): + The VPC security group IDs, in the form sg-xxxxxxxx. Specify the + security groups for the VPC that is specified in the Subnets field. + subnets (Optional[List[str]]): + The ID of the subnets in the VPC to which you want to connect your + training job or model. + enable_network_isolation (Optional[bool]): + Isolates the training container. No inbound or outbound network calls can be made, + except for calls between peers within a training cluster for distributed training. + If you enable network isolation for training jobs that are configured to use a VPC, + SageMaker downloads and uploads customer data and model artifacts through the + specified VPC, but the training container does not have network access. + enable_inter_container_traffic_encryption (Optional[bool]): + To encrypt all communications between ML compute instances in distributed training + choose True. Encryption provides greater security for distributed training, but + training might take longer. How long it takes depends on the amount of + communication between compute instances, especially if you use a deep learning + algorithm in distributed training.
+ """ + + enable_network_isolation: Optional[bool] = None + enable_inter_container_traffic_encryption: Optional[bool] = None + + @model_validator(mode="after") + def _model_validator(self) -> "Networking": + """Convert Unassigned values to None.""" + return convert_unassigned_to_none(self) + + def _to_vpc_config(self) -> shapes.VpcConfig: + """Convert to a sagemaker_core.shapes.VpcConfig object.""" + compute_config_dict = self.model_dump() + vpc_config_fields = set(shapes.VpcConfig.__annotations__.keys()) + filtered_dict = { + k: v for k, v in compute_config_dict.items() if k in vpc_config_fields and v is not None + } + if not filtered_dict: + return None + return shapes.VpcConfig(**filtered_dict) + + +class InputData(BaseConfig): + """InputData. + + This config allows the user to specify an input data source for the training job. + + Will be found at ``/opt/ml/input/data/`` within the training container. + For convience, can be referenced inside the training container like: + + .. code:: python + + import os + input_data_dir = os.environ['SM_CHANNEL_'] + + Parameters: + channel_name (str): + The name of the input data source channel. + data_source (Union[str, S3DataSource, FileSystemDataSource]): + The data source for the channel. Can be an S3 URI string, local file path string, + S3DataSource object, or FileSystemDataSource object. + """ + + channel_name: str = None + data_source: Union[str, FileSystemDataSource, S3DataSource] = None + + +class OutputDataConfig(shapes.OutputDataConfig): + """OutputDataConfig. + + The OutputDataConfig class is a subclass of ``sagemaker_core.shapes.OutputDataConfig`` + and allows the user to specify the output data configuration for the training job. + + Parameters: + s3_output_path (Optional[str]): + The S3 URI where the output data will be stored. This is the location where the + training job will save its output data, such as model artifacts and logs. + kms_key_id (Optional[str]): + The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that + SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side + encryption. + compression_type (Optional[str]): + The model output compression type. Select `NONE` to output an uncompressed model, + recommended for large model outputs. Defaults to `GZIP`. + """ + + s3_output_path: Optional[str] = None + kms_key_id: Optional[str] = None + compression_type: Optional[str] = None + + +class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig): + """TensorBoardOutputConfig. + + The TensorBoardOutputConfig class is a subclass of ``sagemaker_core.shapes.TensorBoardOutputConfig`` + and allows the user to specify the storage locations for the Amazon SageMaker + Debugger TensorBoard. + + Parameters: + s3_output_path (Optional[str]): + Path to Amazon S3 storage location for TensorBoard output. If not specified, will + default to + ``s3://////tensorboard-output`` + local_path (Optional[str]): + Path to local storage location for tensorBoard output. Defaults to /opt/ml/output/tensorboard. + """ + + s3_output_path: Optional[str] = None + local_path: Optional[str] = "/opt/ml/output/tensorboard" + + +class CheckpointConfig(shapes.CheckpointConfig): + """CheckpointConfig. + + The CheckpointConfig class is a subclass of ``sagemaker_core.shapes.CheckpointConfig`` + and allows the user to specify the checkpoint configuration for the training job. + + Parameters: + s3_uri (Optional[str]): + Path to Amazon S3 storage location for the Checkpoint data. 
If not specified, will + default to + ``s3://////checkpoints`` + local_path (Optional[str]): + The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints. + """ + + s3_uri: Optional[str] = None + local_path: Optional[str] = "/opt/ml/checkpoints" diff --git a/src/sagemaker/modules/constants.py b/src/sagemaker/modules/constants.py new file mode 100644 index 0000000000..eaf9d131ef --- /dev/null +++ b/src/sagemaker/modules/constants.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Constants module.""" +from __future__ import absolute_import +import os + +DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" + +SM_CODE = "code" +SM_CODE_CONTAINER_PATH = "/opt/ml/input/data/code" + +SM_DRIVERS = "sm_drivers" +SM_DRIVERS_CONTAINER_PATH = "/opt/ml/input/data/sm_drivers" +SM_DRIVERS_LOCAL_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "train/container_drivers" +) + +SM_RECIPE = "recipe" +SM_RECIPE_YAML = "recipe.yaml" +SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}" + +SOURCE_CODE_JSON = "sourcecode.json" +DISTRIBUTED_JSON = "distributed.json" +TRAIN_SCRIPT = "sm_train.sh" + +DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"] +DEFAULT_CONTAINER_ARGUMENTS = [ + "-c", + f"chmod +x {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT} " + + f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}", +] diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py new file mode 100644 index 0000000000..f248b9b77c --- /dev/null +++ b/src/sagemaker/modules/distributed.py @@ -0,0 +1,181 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Distributed module.""" +from __future__ import absolute_import + +import os + +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any, List +from sagemaker.modules.utils import safe_serialize +from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH +from sagemaker.modules.configs import BaseConfig + + +class SMP(BaseConfig): + """SMP. + + This class is used for configuring the SageMaker Model Parallelism v2 parameters. + For more information on the model parallelism parameters, see: + https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-model-parallel-v2-reference.html#distributed-model-parallel-v2-reference-init-config + + Parameters: + hybrid_shard_degree (Optional[int]): + Specifies a sharded parallelism degree for the model. 
+ sm_activation_offloading (Optional[bool]): + Specifies whether to enable the SMP activation offloading implementation. + activation_loading_horizon (Optional[int]): + An integer specifying the activation offloading horizon type for FSDP. This is the + maximum number of checkpointed or offloaded layers whose inputs can be in the GPU + memory simultaneously. + fsdp_cache_flush_warnings (Optional[bool]): + Detects and warns if cache flushes happen in the PyTorch memory manager, because they + can degrade computational performance. + allow_empty_shards (Optional[bool]): + Whether to allow empty shards when sharding tensors if the tensor is not divisible. This is + an experimental fix for a crash during checkpointing in certain scenarios. Disabling this + falls back to the original PyTorch behavior. + tensor_parallel_degree (Optional[int]): + Specifies a tensor parallelism degree. The value must be between 1 and world_size. + context_parallel_degree (Optional[int]): + Specifies the context parallelism degree. The value must be between 1 and world_size, + and must be <= hybrid_shard_degree. + expert_parallel_degree (Optional[int]): + Specifies an expert parallelism degree. The value must be between 1 and world_size. + random_seed (Optional[int]): + A seed number for the random operations in distributed modules by SMP tensor + parallelism or expert parallelism. + """ + + hybrid_shard_degree: Optional[int] = None + sm_activation_offloading: Optional[bool] = None + activation_loading_horizon: Optional[int] = None + fsdp_cache_flush_warnings: Optional[bool] = None + allow_empty_shards: Optional[bool] = None + tensor_parallel_degree: Optional[int] = None + context_parallel_degree: Optional[int] = None + expert_parallel_degree: Optional[int] = None + random_seed: Optional[int] = None + + def _to_mp_hyperparameters(self) -> Dict[str, Any]: + """Converts to the hyperparameters format for the SageMaker Model Parallelism v2.""" + mp_parameters = self.model_dump(exclude_none=True) + hyperparameters = { + "mp_parameters": safe_serialize(mp_parameters), + } + return hyperparameters + + +class DistributedConfig(BaseConfig, ABC): + """Abstract base class for distributed training configurations. + + This class defines the interface that all distributed training configurations + must implement. It provides a standardized way to specify driver scripts and + their locations for distributed training jobs. + """ + + @property + @abstractmethod + def driver_dir(self) -> str: + """Directory containing the driver script. + + This property should return the path to the directory containing + the driver script, relative to the container's working directory. + + Returns: + str: Path to directory containing the driver script + """ + + @property + @abstractmethod + def driver_script(self) -> str: + """Name of the driver script. + + This property should return the name of the Python script that implements + the distributed training driver logic. + + Returns: + str: Name of the driver script file + """ + + +class Torchrun(DistributedConfig): + """Torchrun. + + The Torchrun class configures a job that uses ``torchrun`` or + ``torch.distributed.launch`` in the backend to launch distributed training. + + Parameters: + process_count_per_node (int): + The number of processes to run on each node in the training job. + Will default to the number of GPUs available in the container. + smp (Optional[SMP]): + The SageMaker Model Parallelism v2 parameters.
+ """ + + process_count_per_node: Optional[int] = None + smp: Optional["SMP"] = None + + @property + def driver_dir(self) -> str: + """Directory containing the driver script. + + Returns: + str: Path to directory containing the driver script + """ + return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers") + + @property + def driver_script(self) -> str: + """Name of the driver script. + + Returns: + str: Name of the driver script file + """ + return "torchrun_driver.py" + + +class MPI(DistributedConfig): + """MPI. + + The MPI class configures a job that uses ``mpirun`` in the backend to launch + distributed training. + + Parameters: + process_count_per_node (int): + The number of processes to run on each node in the training job. + Will default to the number of GPUs available in the container. + mpi_additional_options (Optional[str]): + The custom MPI options to use for the training job. + """ + + process_count_per_node: Optional[int] = None + mpi_additional_options: Optional[List[str]] = None + + @property + def driver_dir(self) -> str: + """Directory containing the driver script. + + Returns: + str: Path to directory containing the driver script + """ + return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers") + + @property + def driver_script(self) -> str: + """Name of the driver script. + + Returns: + str: Name of the driver script + """ + return "mpi_driver.py" diff --git a/src/sagemaker/modules/local_core/local_container.py b/src/sagemaker/modules/local_core/local_container.py new file mode 100644 index 0000000000..448330092d --- /dev/null +++ b/src/sagemaker/modules/local_core/local_container.py @@ -0,0 +1,600 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""LocalContainer class module.""" +from __future__ import absolute_import + +import base64 +import os +import re +import shutil +import subprocess +from tempfile import TemporaryDirectory +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, ConfigDict + +from sagemaker.local.image import ( + _Volume, + _aws_credentials, + _check_output, + _pull_image, + _stream_output, + _write_json_file, +) +from sagemaker.local.utils import check_for_studio, recursive_copy +from sagemaker.model import DIR_PARAM_NAME +from sagemaker.modules import logger, Session +from sagemaker.modules.configs import Channel +from sagemaker.utils import ECR_URI_PATTERN, create_tar_file, _module_import_error, download_folder +from sagemaker_core.main.utils import Unassigned +from sagemaker_core.shapes import DataSource + +from six.moves.urllib.parse import urlparse + +STUDIO_HOST_NAME = "sagemaker-local" +DOCKER_COMPOSE_FILENAME = "docker-compose.yaml" +DOCKER_COMPOSE_HTTP_TIMEOUT_ENV = "COMPOSE_HTTP_TIMEOUT" +DOCKER_COMPOSE_HTTP_TIMEOUT = "120" + +REGION_ENV_NAME = "AWS_REGION" +TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME" +S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL" +S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL" +SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE" + + +class _LocalContainer(BaseModel): + """A local training job class for local mode model trainer. + + Attributes: + training_job_name (str): + The name of the training job. + instance_type (str): + The instance type. + instance_count (int): + The number of instances. + image (str): + The image name for training. + container_root (str): + The directory path for the local container root. + input_from_s3 (bool): + If the input is from s3. + is_studio (bool): + If the container is running on SageMaker studio instance. + hosts (Optional[List[str]]): + The list of host names. + input_data_config: Optional[List[Channel]] + The input data channels for the training job. + Takes a list of Channel objects or a dictionary of channel names to DataSourceType. + DataSourceType can be an S3 URI string, local file path string, + S3DataSource object, or FileSystemDataSource object. + environment (Optional[Dict[str, str]]): + The environment variables for the training job. + hyper_parameters (Optional[Dict[str, Any]]): + The hyperparameters for the training job. + sagemaker_session (Optional[Session]): + The SageMaker session. + For local mode training, SageMaker session will only be used when input is from S3 or + image needs to be pulled from ECR. + container_entrypoint (Optional[List[str]]): + The command to be executed in the container. + container_arguments (Optional[List[str]]): + The arguments of the container commands. 
+ """ + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + + training_job_name: str + instance_type: str + instance_count: int + image: str + container_root: str + input_from_s3: Optional[bool] = False + is_studio: Optional[bool] = False + hosts: Optional[List[str]] = [] + input_data_config: Optional[List[Channel]] + environment: Optional[Dict[str, str]] + hyper_parameters: Optional[Dict[str, str]] + sagemaker_session: Optional[Session] = None + container_entrypoint: Optional[List[str]] + container_arguments: Optional[List[str]] + + _temporary_folders: List[str] = [] + + def model_post_init(self, __context: Any): + """Post init method to perform custom validation and set default values.""" + self.hosts = [f"algo-{i}" for i in range(1, self.instance_count + 1)] + if self.environment is None: + self.environment = {} + if self.hyper_parameters is None: + self.hyper_parameters = {} + + for channel in self.input_data_config: + if channel.data_source and channel.data_source.s3_data_source != Unassigned(): + self.input_from_s3 = True + data_distribution = channel.data_source.s3_data_source.s3_data_distribution_type + if self.sagemaker_session is None: + # In local mode only initiate session when neccessary + self.sagemaker_session = Session() + elif ( + channel.data_source and channel.data_source.file_system_data_source != Unassigned() + ): + self.input_from_s3 = False + data_distribution = channel.data_source.file_system_data_source.file_system_type + else: + raise ValueError( + "Need channel.data_source to have s3_data_source or file_system_data_source" + ) + + supported_distributions = ["FullyReplicated", "EFS"] + if data_distribution and data_distribution not in supported_distributions: + raise RuntimeError( + "Invalid Data Distribution: '{}'. Local mode currently supports FullyReplicated " + "Distribution for S3 data source and EFS Distribution for local data source.".format( + data_distribution, + ) + ) + self.is_studio = check_for_studio() + + def train( + self, + wait: bool, + ) -> str: + """Run a training job locally using docker-compose. + + Args: + wait (bool): + Whether to wait the training output before exiting. + """ + # create output/data folder since sagemaker-containers 2.0 expects it + os.makedirs(os.path.join(self.container_root, "output", "data"), exist_ok=True) + # A shared directory for all the containers. It is only mounted if the training script is + # Local. 
+    def train(
+        self,
+        wait: bool,
+    ) -> str:
+        """Run a training job locally using docker-compose.
+
+        Args:
+            wait (bool):
+                Whether to wait for the training job to finish before exiting.
+        """
+        # create output/data folder since sagemaker-containers 2.0 expects it
+        os.makedirs(os.path.join(self.container_root, "output", "data"), exist_ok=True)
+        # A shared directory for all the containers. It is only mounted if the
+        # training script is local.
+        os.makedirs(os.path.join(self.container_root, "shared"), exist_ok=True)
+
+        data_dir = os.path.join(self.container_root, "input", "data")
+        os.makedirs(data_dir, exist_ok=True)
+        volumes = self._prepare_training_volumes(
+            data_dir, self.input_data_config, self.hyper_parameters
+        )
+        # If local, the source directory needs to be updated to the mounted /opt/ml/code path
+        if DIR_PARAM_NAME in self.hyper_parameters:
+            src_dir = self.hyper_parameters[DIR_PARAM_NAME]
+            parsed_uri = urlparse(src_dir)
+            if parsed_uri.scheme == "file":
+                self.hyper_parameters[DIR_PARAM_NAME] = "/opt/ml/code"
+
+        for host in self.hosts:
+            # Create the configuration files
+            self._create_config_file_directories(host)
+            self._write_config_files(host, self.input_data_config, self.hyper_parameters)
+
+        self.environment[TRAINING_JOB_NAME_ENV_NAME] = self.training_job_name
+        if self.input_from_s3:
+            self.environment[S3_ENDPOINT_URL_ENV_NAME] = (
+                self.sagemaker_session.s3_resource.meta.client._endpoint.host
+            )
+
+        if self._ecr_login_if_needed():
+            _pull_image(self.image)
+
+        if self.sagemaker_session:
+            self.environment[REGION_ENV_NAME] = self.sagemaker_session.boto_region_name
+
+        compose_data = self._generate_compose_file(self.environment, volumes)
+        compose_command = self._generate_compose_command(wait)
+        process = subprocess.Popen(
+            compose_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
+        )
+
+        try:
+            _stream_output(process)
+        finally:
+            artifacts = self.retrieve_artifacts(compose_data)
+
+        # Log the job-complete line and the artifact location
+        logger.info("Local training job completed, output artifacts saved to %s", artifacts)
+
+        shutil.rmtree(os.path.join(self.container_root, "input"))
+        shutil.rmtree(os.path.join(self.container_root, "shared"))
+        for host in self.hosts:
+            shutil.rmtree(os.path.join(self.container_root, host))
+        for folder in self._temporary_folders:
+            shutil.rmtree(os.path.join(self.container_root, folder))
+        return artifacts
+ """ + # We need a directory to store the artfiacts from all the nodes + # and another one to contained the compressed final artifacts + artifacts = os.path.join(self.container_root, "artifacts") + compressed_artifacts = os.path.join(self.container_root, "compressed_artifacts") + os.makedirs(artifacts, exist_ok=True) + + model_artifacts = os.path.join(artifacts, "model") + output_artifacts = os.path.join(artifacts, "output") + + artifact_dirs = [model_artifacts, output_artifacts, compressed_artifacts] + for d in artifact_dirs: + os.makedirs(d, exist_ok=True) + + # Gather the artifacts from all nodes into artifacts/model and artifacts/output + for host in self.hosts: + volumes = compose_data["services"][str(host)]["volumes"] + volumes = [v[:-2] if v.endswith(":z") else v for v in volumes] + for volume in volumes: + if re.search(r"^[A-Za-z]:", volume): + unit, host_dir, container_dir = volume.split(":") + host_dir = unit + ":" + host_dir + else: + host_dir, container_dir = volume.split(":") + if container_dir == "/opt/ml/model": + recursive_copy(host_dir, model_artifacts) + elif container_dir == "/opt/ml/output": + recursive_copy(host_dir, output_artifacts) + + # Tar Artifacts -> model.tar.gz and output.tar.gz + model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)] + output_files = [ + os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts) + ] + create_tar_file(model_files, os.path.join(compressed_artifacts, "model.tar.gz")) + create_tar_file(output_files, os.path.join(compressed_artifacts, "output.tar.gz")) + + output_data = "file://%s" % compressed_artifacts + + return os.path.join(output_data, "model.tar.gz") + + def _create_config_file_directories(self, host: str): + """Creates the directories for the config files. + + Args: + host (str): The name of the current host. + """ + for d in ["input", "input/config", "output", "model"]: + os.makedirs(os.path.join(self.container_root, host, d), exist_ok=True) + + def _write_config_files( + self, + host: str, + input_data_config: Optional[List[Channel]], + hyper_parameters: Optional[Dict[str, str]], + ): + """Write the config files for the training containers. + + This method writes the hyper_parameters, resources and input data + configuration files. + + Returns: None + + Args: + host (str): The name of the current host. + input_data_config (List[Channel]): Training input channels to be used for + training. + hyper_parameters (Dict[str, str]): Hyperparameters for training. + """ + config_path = os.path.join(self.container_root, host, "input", "config") + # Only support single container now + resource_config = { + "current_host": host, + "hosts": self.hosts, + "network_interface_name": "ethwe", + "current_instance_type": self.instance_type, + } + + json_input_data_config = {} + for channel in input_data_config: + channel_name = channel.channel_name + json_input_data_config[channel_name] = {"TrainingInputMode": "File"} + if channel.content_type != Unassigned(): + json_input_data_config[channel_name]["ContentType"] = channel.content_type + + _write_json_file(os.path.join(config_path, "hyperparameters.json"), hyper_parameters) + _write_json_file(os.path.join(config_path, "resourceconfig.json"), resource_config) + _write_json_file(os.path.join(config_path, "inputdataconfig.json"), json_input_data_config) + + def _generate_compose_file(self, environment: Dict[str, str], volumes: List[str]) -> dict: + """Writes a config file describing a training/hosting environment. 
+
+    def _generate_compose_file(self, environment: Dict[str, str], volumes: List[str]) -> dict:
+        """Writes a config file describing a training/hosting environment.
+
+        This method generates a docker compose configuration file with an
+        entry for each container that will be created (based on ``self.hosts``),
+        calling ``_create_docker_host`` to generate the config for each
+        individual container.
+
+        Args:
+            environment (Dict[str, str]): a dictionary with environment variables to be
+                passed on to the containers.
+            volumes (List[str]): a list of volumes that will be mapped to
+                the containers
+
+        Returns: (dict) A dictionary representation of the configuration that was written.
+        """
+
+        if os.environ.get(DOCKER_COMPOSE_HTTP_TIMEOUT_ENV) is None:
+            os.environ[DOCKER_COMPOSE_HTTP_TIMEOUT_ENV] = DOCKER_COMPOSE_HTTP_TIMEOUT
+
+        services = {
+            host: self._create_docker_host(host, environment, volumes) for host in self.hosts
+        }
+
+        if self.is_studio:
+            content = {
+                "services": services,
+            }
+        else:
+            content = {
+                "services": services,
+                "networks": {"sagemaker-local": {"name": "sagemaker-local"}},
+            }
+
+        docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME)
+
+        try:
+            import yaml
+        except ImportError as e:
+            logger.error(_module_import_error("yaml", "Local mode", "local"))
+            raise e
+
+        yaml_content = yaml.dump(content, default_flow_style=False)
+        with open(docker_compose_path, "w") as f:
+            f.write(yaml_content)
+
+        return content
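The returned dictionary (and hence the written docker-compose.yaml) has roughly this shape for a single host outside Studio; the image name and paths below are placeholders:

```python
# Illustrative compose structure produced by _generate_compose_file.
content = {
    "services": {
        "algo-1": {
            "image": "my-training-image:latest",
            "volumes": [
                "/tmp/my-local-job/model:/opt/ml/model",
                "/home/me/data/train:/opt/ml/input/data/train",
                "/tmp/my-local-job/algo-1/output:/opt/ml/output",
                "/tmp/my-local-job/algo-1/output/data:/opt/ml/output/data",
                "/tmp/my-local-job/algo-1/input:/opt/ml/input",
            ],
            "environment": ["TRAINING_JOB_NAME=my-local-job"],
            "networks": {"sagemaker-local": {"aliases": ["algo-1"]}},
        }
    },
    "networks": {"sagemaker-local": {"name": "sagemaker-local"}},
}
```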
+ """ + _compose_cmd_prefix = self._get_compose_cmd_prefix() + + command = _compose_cmd_prefix + [ + "-f", + os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME), + "up", + "--build", + "--abort-on-container-exit" if wait else "--detach", + ] + + logger.info("docker command: %s", " ".join(command)) + return command + + def _ecr_login_if_needed(self): + """Log into ECR, if needed. + + Only ECR images that not have been pulled locally need login. + """ + sagemaker_pattern = re.compile(ECR_URI_PATTERN) + sagemaker_match = sagemaker_pattern.match(self.image) + if not sagemaker_match: + return False + + # Do we already have the image locally? + if _check_output("docker images -q %s" % self.image).strip(): + return False + + if not self.sagemaker_session: + # In local mode only initiate session when neccessary + self.sagemaker_session = Session() + + ecr = self.sagemaker_session.boto_session.client("ecr") + auth = ecr.get_authorization_token(registryIds=[self.image.split(".")[0]]) + authorization_data = auth["authorizationData"][0] + + raw_token = base64.b64decode(authorization_data["authorizationToken"]) + token = raw_token.decode("utf-8").strip("AWS:") + ecr_url = auth["authorizationData"][0]["proxyEndpoint"] + + # Log in to ecr, but use communicate to not print creds to the console + cmd = f"docker login {ecr_url} -u AWS --password-stdin".split() + proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + ) + + proc.communicate(input=token.encode()) + + return True + + def _prepare_training_volumes( + self, + data_dir: str, + input_data_config: Optional[List[Channel]], + hyper_parameters: Optional[Dict[str, str]], + ) -> List[str]: + """Prepares the training volumes based on input and output data configs. + + Args: + data_dir (str): The directory of input data. + input_data_config (Optional[List[Channel]]): Training input channels to be used for + training. + hyper_parameters (Optional[Dict[str, str]]): Hyperparameters for training. + """ + volumes = [] + model_dir = os.path.join(self.container_root, "model") + volumes.append(_Volume(model_dir, "/opt/ml/model").map) + + # Mount the metadata directory if present. + # Only expected to be present on SM notebook instances. + # This is used by some DeepEngine libraries + metadata_dir = "/opt/ml/metadata" + if os.path.isdir(metadata_dir): + volumes.append(_Volume(metadata_dir, metadata_dir).map) + + # Set up the channels for the containers. For local data we will + # mount the local directory to the container. For S3 Data we will download the S3 data + # first. + for channel in input_data_config: + channel_name = channel.channel_name + channel_dir = os.path.join(data_dir, channel_name) + os.makedirs(channel_dir, exist_ok=True) + + data_source_local_path = self._get_data_source_local_path(channel.data_source) + volumes.append(_Volume(data_source_local_path, channel=channel_name).map) + + # If there is a training script directory and it is a local directory, + # mount it to the container. + if DIR_PARAM_NAME in hyper_parameters: + training_dir = hyper_parameters[DIR_PARAM_NAME] + parsed_uri = urlparse(training_dir) + if parsed_uri.scheme == "file": + host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path) + volumes.append(_Volume(host_dir, "/opt/ml/code").map) + shared_dir = os.path.join(self.container_root, "shared") + volumes.append(_Volume(shared_dir, "/opt/ml/shared").map) + + return volumes + + def _get_data_source_local_path(self, data_source: DataSource): + """Return a local data path of :class:`sagemaker.local.data.DataSource`. 
+
+    def _get_data_source_local_path(self, data_source: DataSource):
+        """Return a local data path of :class:`sagemaker.local.data.DataSource`.
+
+        If the data source is from S3, the data will be downloaded to a temporary
+        local path.
+        If the data source is a local file, its absolute path is returned.
+
+        Args:
+            data_source (DataSource): a data source backed by a local file or S3
+
+        Returns:
+            str: The local path of the data.
+        """
+        if data_source.s3_data_source != Unassigned():
+            uri = data_source.s3_data_source.s3_uri
+            parsed_uri = urlparse(uri)
+            local_dir = TemporaryDirectory(prefix=os.path.join(self.container_root + "/")).name
+            self._temporary_folders.append(local_dir)
+            download_folder(parsed_uri.netloc, parsed_uri.path, local_dir, self.sagemaker_session)
+            return local_dir
+        else:
+            return os.path.abspath(data_source.file_system_data_source.directory_path)
+
+    def _get_compose_cmd_prefix(self) -> List[str]:
+        """Gets the Docker Compose command.
+
+        The method first looks for the 'docker compose' (v2) executable; if that
+        is not found, it falls back to the 'docker-compose' executable.
+
+        Returns:
+            List[str]: Docker Compose executable split into list.
+
+        Raises:
+            ImportError: If no Docker Compose executable was found.
+        """
+        compose_cmd_prefix = []
+
+        output = None
+        try:
+            output = subprocess.check_output(
+                ["docker", "compose", "version"],
+                stderr=subprocess.DEVNULL,
+                encoding="UTF-8",
+            )
+        except subprocess.CalledProcessError:
+            logger.info(
+                "'Docker Compose' is not installed. "
+                "Proceeding to check for 'docker-compose' CLI."
+            )
+
+        if output and "v2" in output.strip():
+            logger.info("'Docker Compose' found using Docker CLI.")
+            compose_cmd_prefix.extend(["docker", "compose"])
+            return compose_cmd_prefix
+
+        if shutil.which("docker-compose") is not None:
+            logger.info("'Docker Compose' found using Docker Compose CLI.")
+            compose_cmd_prefix.extend(["docker-compose"])
+            return compose_cmd_prefix
+
+        raise ImportError(
+            "Docker Compose is not installed. "
+            "Local Mode features will not work without docker compose. "
+            "For more information on how to install 'docker compose', see "
+            "https://docs.docker.com/compose/install/"
+        )
diff --git a/src/sagemaker/modules/templates.py b/src/sagemaker/modules/templates.py
new file mode 100644
index 0000000000..d888b7bcb9
--- /dev/null
+++ b/src/sagemaker/modules/templates.py
@@ -0,0 +1,83 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Templates module."""
+from __future__ import absolute_import
+
+EXECUTE_BASE_COMMANDS = """
+CMD="{base_command}"
+echo "Executing command: $CMD"
+eval $CMD
+"""
+
+EXECUTE_BASIC_SCRIPT_DRIVER = """
+echo "Running Basic Script driver"
+$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py
+"""
+
+EXEUCTE_DISTRIBUTED_DRIVER = """
+echo "Running {driver_name} Driver"
+$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/{driver_script}
+"""
+
+TRAIN_SCRIPT_TEMPLATE = """
+#!/bin/bash
+set -e
+echo "Starting training script"
+
+handle_error() {{
+    EXIT_STATUS=$?
+    echo "An error occurred with exit code $EXIT_STATUS"
+    if [ !
-s /opt/ml/output/failure ]; then + echo "Training Execution failed. For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'. +TrainingJob - $TRAINING_JOB_NAME" >> /opt/ml/output/failure + fi + exit $EXIT_STATUS +}} + +check_python() {{ + SM_PYTHON_CMD=$(command -v python3 || command -v python) + SM_PIP_CMD=$(command -v pip3 || command -v pip) + + # Check if Python is found + if [[ -z "$SM_PYTHON_CMD" || -z "$SM_PIP_CMD" ]]; then + echo "Error: The Python executable was not found in the system path." + return 1 + fi + + return 0 +}} + +trap 'handle_error' ERR + +check_python + +$SM_PYTHON_CMD --version + +echo "/opt/ml/input/config/resourceconfig.json:" +cat /opt/ml/input/config/resourceconfig.json +echo + +echo "/opt/ml/input/config/inputdataconfig.json:" +cat /opt/ml/input/config/inputdataconfig.json +echo + +echo "Setting up environment variables" +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/scripts/environment.py +source /opt/ml/input/sm_training.env + +{working_dir} +{install_requirements} +{execute_driver} + +echo "Training Container Execution Completed" +""" diff --git a/src/sagemaker/modules/train/__init__.py b/src/sagemaker/modules/train/__init__.py new file mode 100644 index 0000000000..51fa17fe04 --- /dev/null +++ b/src/sagemaker/modules/train/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules train directory.""" +from __future__ import absolute_import + +from sagemaker.modules.train.model_trainer import ModelTrainer # noqa: F401 diff --git a/src/sagemaker/modules/train/container_drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/__init__.py new file mode 100644 index 0000000000..864f3663b8 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container drivers directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/common/__init__.py b/src/sagemaker/modules/train/container_drivers/common/__init__.py new file mode 100644 index 0000000000..aab88c6b97 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/common/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container drivers - common directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/common/utils.py b/src/sagemaker/modules/train/container_drivers/common/utils.py new file mode 100644 index 0000000000..a94416550d --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/common/utils.py @@ -0,0 +1,204 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides utility functions for the container drivers.""" +from __future__ import absolute_import + +import os +import logging +import sys +import subprocess +import traceback +import json + +from typing import List, Dict, Any, Tuple, IO, Optional + +# Initialize logger +SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) +logger = logging.getLogger(__name__) +console_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(console_handler) +logger.setLevel(int(SM_LOG_LEVEL)) + +FAILURE_FILE = "/opt/ml/output/failure" +DEFAULT_FAILURE_MESSAGE = """ +Training Execution failed. +For more details, see CloudWatch logs at 'aws/sagemaker/TrainingJobs'. 
+TrainingJob - {training_job_name}
+"""
+
+USER_CODE_PATH = "/opt/ml/input/data/code"
+SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json"
+DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json"
+
+HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json"
+
+SM_EFA_NCCL_INSTANCES = [
+    "ml.g4dn.8xlarge",
+    "ml.g4dn.12xlarge",
+    "ml.g5.48xlarge",
+    "ml.p3dn.24xlarge",
+    "ml.p4d.24xlarge",
+    "ml.p4de.24xlarge",
+    "ml.p5.48xlarge",
+    "ml.trn1.32xlarge",
+]
+
+SM_EFA_RDMA_INSTANCES = [
+    "ml.p4d.24xlarge",
+    "ml.p4de.24xlarge",
+    "ml.trn1.32xlarge",
+]
+
+
+def write_failure_file(message: Optional[str] = None):
+    """Write a failure file with the message."""
+    if message is None:
+        message = DEFAULT_FAILURE_MESSAGE.format(training_job_name=os.environ["TRAINING_JOB_NAME"])
+    if not os.path.exists(FAILURE_FILE):
+        with open(FAILURE_FILE, "w") as f:
+            f.write(message)
+
+
+def read_source_code_json(source_code_json: str = SOURCE_CODE_JSON) -> Dict[str, Any]:
+    """Read the source code config json file."""
+    try:
+        with open(source_code_json, "r") as f:
+            source_code_dict = json.load(f) or {}
+    except FileNotFoundError:
+        source_code_dict = {}
+    return source_code_dict
+
+
+def read_distributed_json(distributed_json: str = DISTRIBUTED_JSON) -> Dict[str, Any]:
+    """Read the distribution config json file."""
+    try:
+        with open(distributed_json, "r") as f:
+            distributed_dict = json.load(f) or {}
+    except FileNotFoundError:
+        distributed_dict = {}
+    return distributed_dict
+
+
+def read_hyperparameters_json(hyperparameters_json: str = HYPERPARAMETERS_JSON) -> Dict[str, Any]:
+    """Read the hyperparameters config json file."""
+    try:
+        with open(hyperparameters_json, "r") as f:
+            hyperparameters_dict = json.load(f) or {}
+    except FileNotFoundError:
+        hyperparameters_dict = {}
+    return hyperparameters_dict
+
+
+def get_process_count(process_count: Optional[int] = None) -> int:
+    """Get the number of processes to run on each node in the training job."""
+    return (
+        process_count
+        or int(os.environ.get("SM_NUM_GPUS", 0))
+        or int(os.environ.get("SM_NUM_NEURONS", 0))
+        or 1
+    )
+
+
+def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]:
+    """Convert the hyperparameters to CLI arguments."""
+    cli_args = []
+    for key, value in hyperparameters.items():
+        value = safe_deserialize(value)
+        cli_args.extend([f"--{key}", safe_serialize(value)])
+
+    return cli_args
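To make the conversion concrete (example values only), JSON-encoded hyperparameter strings are decoded by `safe_deserialize` and re-emitted by `safe_serialize`, both defined below:

```python
# Illustrative behavior of hyperparameters_to_cli_args.
hps = {"lr": "0.1", "epochs": "10", "tags": '["a", "b"]'}
hyperparameters_to_cli_args(hps)
# -> ["--lr", "0.1", "--epochs", "10", "--tags", '["a", "b"]']
```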
+
+
+def safe_deserialize(data: Any) -> Any:
+    """Safely deserialize data from a JSON string.
+
+    This function handles the following cases:
+    1. If `data` is not a string, it returns the input as-is.
+    2. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`.
+    3. If `data` is a string but cannot be decoded as JSON, it returns the original string.
+
+    Returns:
+        Any: The deserialized data, or the original input if it cannot be JSON-decoded.
+    """
+    if not isinstance(data, str):
+        return data
+    try:
+        return json.loads(data)
+    except json.JSONDecodeError:
+        return data
+
+
+def safe_serialize(data: Any) -> str:
+    """Serialize the data without wrapping strings in quotes.
+
+    This function handles the following cases:
+    1. If `data` is a string, it returns the string as-is without wrapping in quotes.
+    2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
+       the JSON-encoded string using `json.dumps()`.
+    3. If `data` cannot be serialized (e.g., a custom object), it returns the string
+       representation of the data using `str(data)`.
+
+    Args:
+        data (Any): The data to serialize.
+
+    Returns:
+        str: The serialized JSON-compatible string or the string representation of the input.
+    """
+    if isinstance(data, str):
+        return data
+    try:
+        return json.dumps(data)
+    except TypeError:
+        return str(data)
+
+
+def get_python_executable() -> str:
+    """Get the python executable path."""
+    return sys.executable
+
+
+def log_subprocess_output(pipe: IO[bytes]):
+    """Log the output from the subprocess."""
+    for line in iter(pipe.readline, b""):
+        logger.info(line.decode("utf-8").strip())
+
+
+def execute_commands(commands: List[str]) -> Tuple[int, str]:
+    """Execute the provided commands and return exit code with failure traceback if any."""
+    try:
+        process = subprocess.Popen(
+            commands,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+        )
+        with process.stdout:
+            log_subprocess_output(process.stdout)
+        exitcode = process.wait()
+        if exitcode != 0:
+            raise subprocess.CalledProcessError(exitcode, commands)
+        return exitcode, ""
+    except subprocess.CalledProcessError as e:
+        # Capture the traceback in case of failure
+        error_traceback = traceback.format_exc()
+        logger.error(
+            "Command failed with exit code %s. Traceback: %s", e.returncode, error_traceback
+        )
+        return e.returncode, error_traceback
+
+
+def is_worker_node() -> bool:
+    """Check if the current node is a worker node."""
+    return os.environ.get("SM_CURRENT_HOST") != os.environ.get("SM_MASTER_ADDR")
+
+
+def is_master_node() -> bool:
+    """Check if the current node is the master node."""
+    return os.environ.get("SM_CURRENT_HOST") == os.environ.get("SM_MASTER_ADDR")
diff --git a/src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py
new file mode 100644
index 0000000000..a44e7e81a9
--- /dev/null
+++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py
@@ -0,0 +1,14 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Sagemaker modules container drivers - drivers directory."""
+from __future__ import absolute_import
diff --git a/src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py
new file mode 100644
index 0000000000..0b086a8e4f
--- /dev/null
+++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py
@@ -0,0 +1,81 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module is the entry point for the Basic Script Driver.""" +from __future__ import absolute_import + +import os +import sys +import json +import shlex + +from pathlib import Path +from typing import List + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + logger, + get_python_executable, + execute_commands, + write_failure_file, + hyperparameters_to_cli_args, +) + + +def create_commands() -> List[str]: + """Create the commands to execute.""" + entry_script = os.environ["SM_ENTRY_SCRIPT"] + hyperparameters = json.loads(os.environ["SM_HPS"]) + python_executable = get_python_executable() + + args = hyperparameters_to_cli_args(hyperparameters) + if entry_script.endswith(".py"): + commands = [python_executable, entry_script] + commands += args + elif entry_script.endswith(".sh"): + args_str = " ".join(shlex.quote(arg) for arg in args) + commands = [ + "/bin/sh", + "-c", + f"chmod +x {entry_script} && ./{entry_script} {args_str}", + ] + else: + raise ValueError( + f"Unsupported entry script type: {entry_script}. Only .py and .sh are supported." + ) + return commands + + +def main(): + """Main function for the Basic Script Driver. + + This function is the entry point for the Basic Script Driver. + + Execution Lifecycle: + 1. Read the source code and hyperparameters JSON files. + 2. Set hyperparameters as command line arguments. + 3. Create the commands to execute. + 4. Execute the commands. + """ + + cmd = create_commands() + + logger.info(f"Executing command: {' '.join(cmd)}") + exit_code, traceback = execute_commands(cmd) + if exit_code != 0: + write_failure_file(traceback) + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py new file mode 100644 index 0000000000..9946272617 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py @@ -0,0 +1,107 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module is the entry point for the MPI driver script.""" +from __future__ import absolute_import + +import os +import sys +import json +from pathlib import Path + +from mpi_utils import ( + start_sshd_daemon, + bootstrap_master_node, + bootstrap_worker_node, + get_mpirun_command, + write_status_file_to_workers, + write_env_vars_to_file, +) + + +sys.path.insert(0, str(Path(__file__).parent.parent)) +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + logger, + hyperparameters_to_cli_args, + get_process_count, + execute_commands, + write_failure_file, +) + + +def main(): + """Main function for the MPI driver script. + + The MPI Dirver is responsible for setting up the MPI environment, + generating the correct mpi commands, and launching the MPI job. + + Execution Lifecycle: + 1. Setup General Environment Variables at /etc/environment + 2. 
diff --git a/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py
new file mode 100644
index 0000000000..9946272617
--- /dev/null
+++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py
@@ -0,0 +1,107 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module is the entry point for the MPI driver script."""
+from __future__ import absolute_import
+
+import os
+import sys
+import json
+from pathlib import Path
+
+from mpi_utils import (
+    start_sshd_daemon,
+    bootstrap_master_node,
+    bootstrap_worker_node,
+    get_mpirun_command,
+    write_status_file_to_workers,
+    write_env_vars_to_file,
+)
+
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+from common.utils import (  # noqa: E402 # pylint: disable=C0413,E0611
+    logger,
+    hyperparameters_to_cli_args,
+    get_process_count,
+    execute_commands,
+    write_failure_file,
+)
+
+
+def main():
+    """Main function for the MPI driver script.
+
+    The MPI Driver is responsible for setting up the MPI environment,
+    generating the correct mpirun command, and launching the MPI job.
+
+    Execution Lifecycle:
+    1. Set up General Environment Variables at /etc/environment
+    2. Start SSHD Daemon
+    3. Bootstrap Worker Nodes
+        a. Wait to establish connection with Master Node
+        b. Wait for Master Node to write status file
+    4. Bootstrap Master Node
+        a. Wait to establish connection with Worker Nodes
+        b. Generate MPI Command
+        c. Execute MPI Command with user script provided in `entry_script`
+        d. Write status file to Worker Nodes
+    5. Exit
+
+    """
+    entry_script = os.environ["SM_ENTRY_SCRIPT"]
+    distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
+    hyperparameters = json.loads(os.environ["SM_HPS"])
+
+    sm_current_host = os.environ["SM_CURRENT_HOST"]
+    sm_hosts = json.loads(os.environ["SM_HOSTS"])
+    sm_master_addr = os.environ["SM_MASTER_ADDR"]
+
+    write_env_vars_to_file()
+    start_sshd_daemon()
+
+    if sm_current_host != sm_master_addr:
+        # Worker nodes only rendezvous with the master and wait for it to
+        # report completion.
+        bootstrap_worker_node(sm_master_addr)
+    else:
+        worker_hosts = [host for host in sm_hosts if host != sm_master_addr]
+        bootstrap_master_node(worker_hosts)
+
+        # Only the master node builds and launches the mpirun command;
+        # worker_hosts is defined only in this branch.
+        host_list = json.loads(os.environ["SM_HOSTS"])
+        host_count = int(os.environ["SM_HOST_COUNT"])
+        process_count = int(distributed_config["process_count_per_node"] or 0)
+        process_count = get_process_count(process_count)
+
+        if process_count > 1:
+            host_list = ["{}:{}".format(host, process_count) for host in host_list]
+
+        mpi_command = get_mpirun_command(
+            host_count=host_count,
+            host_list=host_list,
+            num_processes=process_count,
+            additional_options=distributed_config["mpi_additional_options"] or [],
+            entry_script_path=entry_script,
+        )
+
+        args = hyperparameters_to_cli_args(hyperparameters)
+        mpi_command += args
+
+        logger.info(f"Executing command: {' '.join(mpi_command)}")
+        exit_code, error_traceback = execute_commands(mpi_command)
+        write_status_file_to_workers(worker_hosts)
+
+        if exit_code != 0:
+            write_failure_file(error_traceback)
+            sys.exit(exit_code)
+
+
+if __name__ == "__main__":
+    main()
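To illustrate the slot expansion performed above (example values; the full flag set comes from `get_mpirun_command` in `mpi_utils` below):

```python
# Illustrative: two hosts with 8 processes per node.
host_list = ["algo-1:8", "algo-2:8"]
# get_mpirun_command then yields a command of the form:
#   mpirun --host algo-1:8,algo-2:8 -np 8 --allow-run-as-root --tag-output \
#       ... -x NCCL_SOCKET_IFNAME=eth0 <python> -m mpi4py <entry_script>
# with hyperparameter CLI args appended by the driver.
```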
+"""This module provides mpi related utility functions for the container drivers.""" +from __future__ import absolute_import + +import os +import sys +import subprocess +import time + +from pathlib import Path +from typing import List + +import paramiko + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + SM_EFA_NCCL_INSTANCES, + SM_EFA_RDMA_INSTANCES, + get_python_executable, + logger, +) + +FINISHED_STATUS_FILE = "/tmp/done.algo-1" +READY_FILE = "/tmp/ready.%s" +DEFAULT_SSH_PORT = 22 + + +def _write_file_to_host(host: str, status_file: str) -> bool: + """Write the a file to the provided host.""" + try: + logger.info(f"Writing {status_file} to {host}") + subprocess.run( + ["ssh", host, "touch", f"{status_file}"], + capture_output=True, + text=True, + check=True, + ) + logger.info("Finished writing status file") + return True + except subprocess.CalledProcessError: + logger.info(f"Cannot connect to {host}") + return False + + +def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE): + """Write the status file to all worker nodes.""" + for worker in worker_hosts: + retry = 0 + while not _write_file_to_host(worker, status_file): + time.sleep(5) + retry += 1 + if retry > 5: + raise TimeoutError(f"Timed out waiting for {worker} to be reachable.") + logger.info(f"Retrying to write status file to {worker}") + + +def _wait_for_status_file(status_file: str): + """Wait for the status file to be created.""" + logger.info(f"Waiting for status file {status_file}") + while not os.path.exists(status_file): + time.sleep(30) + logger.info(f"Found status file {status_file}") + + +def start_sshd_daemon(): + """Start the SSH daemon on the current node.""" + sshd_executable = "/usr/sbin/sshd" + + if not os.path.exists(sshd_executable): + raise RuntimeError("SSH daemon not found.") + + # Start the sshd in daemon mode (-D) + subprocess.Popen([sshd_executable, "-D"]) + logger.info("Started SSH daemon.") + + +class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy): + """Class to handle host key policy for SageMaker distributed training SSH connections. + + Example: + >>> client = paramiko.SSHClient() + >>> client.set_missing_host_key_policy(CustomHostKeyPolicy()) + >>> # Will succeed for SageMaker algorithm containers + >>> client.connect('algo-1234.internal') + >>> # Will raise SSHException for other unknown hosts + >>> client.connect('unknown-host') # raises SSHException + """ + + def missing_host_key(self, client, hostname, key): + """Accept host keys for algo-* hostnames, reject others. 
+ + Args: + client: The SSHClient instance + hostname: The hostname attempting to connect + key: The host key + + Raises: + paramiko.SSHException: If hostname doesn't match algo-* pattern + """ + if hostname.startswith("algo-"): + client.get_host_keys().add(hostname, key.get_name(), key) + return + raise paramiko.SSHException(f"Unknown host key for {hostname}") + + +def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool: + """Check if the connection to the provided host and port is possible.""" + try: + logger.debug("Testing connection to host %s", host) + with paramiko.SSHClient() as client: + client.load_system_host_keys() + client.set_missing_host_key_policy(CustomHostKeyPolicy()) + client.connect(host, port=port) + logger.info("Can connect to host %s", host) + return True + except Exception as e: # pylint: disable=W0703 + logger.info("Cannot connect to host %s", host) + logger.debug(f"Connection failed with exception: {e}") + return False + + +def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Master node waits until it can connect to all worker nodes.""" + start_time = time.time() + if not worker_hosts: + logger.info("No worker nodes to connect to.") + return + + while True: + logger.info("Master is attempting to connect to all workers...") + all_workers_connected = all( + _can_connect(worker, port) and os.path.exists(READY_FILE % worker) + for worker in worker_hosts + ) + + if all_workers_connected: + logger.info("Master can connect to all worker nodes.") + break + if time.time() - start_time > timeout: + raise TimeoutError("Timed out waiting for workers to be reachable.") + + time.sleep(5) # Wait for 5 seconds before trying again + + +def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Worker nodes wait until they can connect to the master node.""" + start_time = time.time() + while True: + logger.info(f"Worker is attempting to connect to the master node {master_host}...") + if _can_connect(master_host, port): + logger.info(f"Worker can connect to master node {master_host}.") + break + if time.time() - start_time > timeout: + raise TimeoutError(f"Timed out waiting for master {master_host} to be reachable.") + + time.sleep(5) # Wait for 5 seconds before trying again + + +def bootstrap_worker_node(master_host: str, status_file: str = FINISHED_STATUS_FILE): + """Bootstrap the worker nodes.""" + logger.info("Bootstrapping worker node...") + _wait_for_master(master_host) + _write_file_to_host(master_host, READY_FILE % os.environ["SM_CURRENT_HOST"]) + _wait_for_status_file(status_file) + + +def bootstrap_master_node(worker_hosts: List[str]): + """Bootstrap the master node.""" + logger.info("Bootstrapping master node...") + _wait_for_workers(worker_hosts) + + +def validate_smddprun() -> bool: + """Whether smddprun is installed. + + Returns: + bool: True if installed + """ + try: + output = subprocess.run( + ["which", "smddprun"], + capture_output=True, + text=True, + check=True, + ) + return output.stdout != "" + except subprocess.CalledProcessError: + return False + + +def validate_smddpmprun() -> bool: + """Whether smddpmprun is installed. 
+
+    Returns:
+        bool: True if it is installed
+    """
+    try:
+        output = subprocess.run(
+            ["which", "smddpmprun"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+        return output.stdout != ""
+    except subprocess.CalledProcessError:
+        return False
+
+
+def write_env_vars_to_file():
+    """Write environment variables to the /etc/environment file."""
+    with open("/etc/environment", "a", encoding="utf-8") as f:
+        for name in os.environ:
+            f.write(f"{name}={os.environ.get(name)}\n")
+
+
+def get_mpirun_command(
+    host_count: int,
+    host_list: List[str],
+    num_processes: int,
+    additional_options: List[str],
+    entry_script_path: str,
+):
+    """Build the mpirun command."""
+    network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0")
+
+    mpirun_command = [
+        "mpirun",
+        "--host",
+        ",".join(host_list),
+        "-np",
+        str(num_processes),
+        "--allow-run-as-root",
+        "--tag-output",
+        "-mca",
+        "btl_tcp_if_include",
+        network_interface_name,
+        "-mca",
+        "oob_tcp_if_include",
+        network_interface_name,
+        "-mca",
+        "plm_rsh_no_tree_spawn",
+        "1",
+        "-mca",
+        "pml",
+        "ob1",
+        "-mca",
+        "btl",
+        "^openib",
+        "-mca",
+        "orte_abort_on_non_zero_status",
+        "1",
+        "-mca",
+        "btl_vader_single_copy_mechanism",
+        "none",
+        "-mca",
+        "plm_rsh_num_concurrent",
+        str(host_count),
+        "-x",
+        "NCCL_SOCKET_IFNAME=%s" % network_interface_name,
+        "-x",
+        "LD_LIBRARY_PATH",
+        "-x",
+        "PATH",
+    ]
+
+    if additional_options:
+        mpirun_command.extend(additional_options)
+
+    instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"]
+    # EFA settings
+    if instance_type in SM_EFA_NCCL_INSTANCES:
+        mpirun_command.extend(["-x", "FI_PROVIDER=efa"])
+        # Use simple protocol to handle the out-of-order data delivery from EFA
+        mpirun_command.extend(["-x", "NCCL_PROTO=simple"])
+
+    if instance_type in SM_EFA_RDMA_INSTANCES:
+        # Use EFA's RDMA functionality for one-sided and two-sided transfer
+        mpirun_command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"])
+
+    for credential in [
+        "AWS_ACCESS_KEY_ID",
+        "AWS_SECRET_ACCESS_KEY",
+        "AWS_SESSION_TOKEN",
+    ]:
+        if credential in os.environ:
+            mpirun_command.extend(["-x", credential])
+
+    mpirun_command.extend([get_python_executable()])
+    mpirun_command.extend(["-m", "mpi4py", entry_script_path])
+    return mpirun_command
diff --git a/src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py b/src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py
new file mode 100644
index 0000000000..7fcfabe05d
--- /dev/null
+++ b/src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py
@@ -0,0 +1,129 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module is the entry point for the Torchrun driver script.""" +from __future__ import absolute_import + +import os +import sys +import json + +from pathlib import Path +from typing import List, Tuple + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + logger, + hyperparameters_to_cli_args, + get_process_count, + get_python_executable, + execute_commands, + write_failure_file, + SM_EFA_NCCL_INSTANCES, + SM_EFA_RDMA_INSTANCES, +) + + +def pytorch_version() -> Tuple[int, int]: + """Get the PyTorch version as a tuple of integers.""" + import torch + + return tuple(map(int, torch.__version__.split(".")[:2])) + + +def get_base_pytorch_command() -> List[str]: + """Get the base Torch Distributed launcher to execute""" + if pytorch_version() >= (1, 9): + return ["torchrun"] + return [f"{get_python_executable()}", "-m", "torch.distributed.launch"] + + +def setup_env(): + """Setup the environment variables for PyTorch distributed training""" + instance_type = os.environ["SM_CURRENT_INSTANCE_TYPE"] + network_interface_name = os.environ.get("SM_NETWORK_INTERFACE_NAME", "eth0") + if instance_type in SM_EFA_NCCL_INSTANCES: + # Enable EFA use + os.environ["FI_PROVIDER"] = "efa" + if instance_type in SM_EFA_RDMA_INSTANCES: + # Use EFA's RDMA functionality for one-sided and two-sided transfer + os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1" + os.environ["RDMAV_FORK_SAFE"] = "1" + os.environ["NCCL_SOCKET_IFNAME"] = str(network_interface_name) + os.environ["NCCL_PROTO"] = "simple" + + +def create_commands(): + """Create the Torch Distributed command to execute""" + entry_script = os.environ["SM_ENTRY_SCRIPT"] + distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + hyperparameters = json.loads(os.environ["SM_HPS"]) + + process_count = int(distributed_config["process_count_per_node"] or 0) + process_count = get_process_count(process_count) + host_count = int(os.environ["SM_HOST_COUNT"]) + + torch_cmd = [] + if os.environ.get("RUN_NEURON_PARALLEL_COMPILE") == "1": + torch_cmd.append("neuron_parallel_compile") + + torch_cmd.extend(get_base_pytorch_command()) + torch_cmd.extend( + [ + f"--nnodes={host_count}", + f"--nproc_per_node={process_count}", + ] + ) + + # If more than one node is used, add node rank information + if int(host_count) > 1: + torch_cmd.extend( + [ + f"--master_addr={os.environ['SM_MASTER_ADDR']}", + f"--master_port={os.environ['SM_MASTER_PORT']}", + f"--node_rank={os.environ['SM_CURRENT_HOST_RANK']}", + ] + ) + + torch_cmd.extend([entry_script]) + + args = hyperparameters_to_cli_args(hyperparameters) + torch_cmd += args + + return torch_cmd + + +def main(): + """Main function to execute the PyTorch distributed training script. + + This function sets some environment variables and executes the PyTorch + distributed training script. + + Execution Lifecycle: + 1. Setup Environment Variables for PyTorch Distributed Training + 2. Create Torch Distributed Command + 3. Execute Torch Distributed Command with user script provided in `entry_script` + 4. 
diff --git a/src/sagemaker/modules/train/container_drivers/scripts/__init__.py b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py
new file mode 100644
index 0000000000..f04c5b17a0
--- /dev/null
+++ b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py
@@ -0,0 +1,14 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Sagemaker modules container drivers - scripts directory."""
+from __future__ import absolute_import
diff --git a/src/sagemaker/modules/train/container_drivers/scripts/environment.py b/src/sagemaker/modules/train/container_drivers/scripts/environment.py
new file mode 100644
index 0000000000..897b1f8af4
--- /dev/null
+++ b/src/sagemaker/modules/train/container_drivers/scripts/environment.py
@@ -0,0 +1,305 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module is used to define the environment variables for the training job container.""" +from __future__ import absolute_import + +from typing import Dict, Any +import multiprocessing +import subprocess +import json +import os +import sys +from pathlib import Path +import logging + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + safe_serialize, + safe_deserialize, + read_distributed_json, + read_source_code_json, +) + +# Initialize logger +SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) +logger = logging.getLogger(__name__) +console_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(console_handler) +logger.setLevel(int(SM_LOG_LEVEL)) + +SM_MODEL_DIR = "/opt/ml/model" + +SM_INPUT_DIR = "/opt/ml/input" +SM_INPUT_DATA_DIR = "/opt/ml/input/data" +SM_INPUT_CONFIG_DIR = "/opt/ml/input/config" + +SM_OUTPUT_DIR = "/opt/ml/output" +SM_OUTPUT_FAILURE = "/opt/ml/output/failure" +SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" +SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code" +SM_DISTRIBUTED_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/distributed_drivers" + +SM_MASTER_ADDR = "algo-1" +SM_MASTER_PORT = 7777 + +RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" +INPUT_DATA_CONFIG = f"{SM_INPUT_CONFIG_DIR}/inputdataconfig.json" +HYPERPARAMETERS_CONFIG = f"{SM_INPUT_CONFIG_DIR}/hyperparameters.json" + +ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" + +SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] +HIDDEN_VALUE = "******" + + +def num_cpus() -> int: + """Return the number of CPUs available in the current container. + + Returns: + int: Number of CPUs available in the current container. + """ + return multiprocessing.cpu_count() + + +def num_gpus() -> int: + """Return the number of GPUs available in the current container. + + Returns: + int: Number of GPUs available in the current container. + """ + try: + cmd = ["nvidia-smi", "--list-gpus"] + output = subprocess.check_output(cmd).decode("utf-8") + return sum(1 for line in output.splitlines() if line.startswith("GPU ")) + except (OSError, subprocess.CalledProcessError): + logger.info("No GPUs detected (normal if no gpus installed)") + return 0 + + +def num_neurons() -> int: + """Return the number of neuron cores available in the current container. + + Returns: + int: Number of Neuron Cores available in the current container. + """ + try: + cmd = ["neuron-ls", "-j"] + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") + j = json.loads(output) + neuron_cores = 0 + for item in j: + neuron_cores += item.get("nc_count", 0) + logger.info("Found %s neurons on this instance", neuron_cores) + return neuron_cores + except OSError: + logger.info("No Neurons detected (normal if no neurons installed)") + return 0 + except subprocess.CalledProcessError as e: + if e.output is not None: + try: + msg = e.output.decode("utf-8").partition("error=")[2] + logger.info( + "No Neurons detected (normal if no neurons installed). \ + If neuron installed then %s", + msg, + ) + except AttributeError: + logger.info("No Neurons detected (normal if no neurons installed)") + else: + logger.info("No Neurons detected (normal if no neurons installed)") + + return 0 + + +def deserialize_hyperparameters(hyperparameters: Dict[str, str]) -> Dict[str, Any]: + """Deserialize hyperparameters from string to their original types. + + Args: + hyperparameters (Dict[str, str]): Hyperparameters as strings. 
+ + Returns: + Dict[str, Any]: Hyperparameters as their original types. + """ + deserialized_hyperparameters = {} + for key, value in hyperparameters.items(): + deserialized_hyperparameters[key] = safe_deserialize(value) + return deserialized_hyperparameters + + +def set_env( + resource_config: Dict[str, Any], + input_data_config: Dict[str, Any], + hyperparameters_config: Dict[str, Any], + output_file: str = ENV_OUTPUT_FILE, +): + """Set environment variables for the training job container. + + Args: + resource_config (Dict[str, Any]): Resource configuration for the training job. + input_data_config (Dict[str, Any]): Input data configuration for the training job. + hyperparameters_config (Dict[str, Any]): Hyperparameters configuration for the training job. + output_file (str): Output file to write the environment variables. + """ + # Constants + env_vars = { + "SM_MODEL_DIR": SM_MODEL_DIR, + "SM_INPUT_DIR": SM_INPUT_DIR, + "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, + "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, + "SM_OUTPUT_DIR": SM_OUTPUT_DIR, + "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, + "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, + "SM_LOG_LEVEL": SM_LOG_LEVEL, + "SM_MASTER_ADDR": SM_MASTER_ADDR, + "SM_MASTER_PORT": SM_MASTER_PORT, + } + + # SourceCode and DistributedConfig Environment Variables + source_code = read_source_code_json() + if source_code: + env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH + env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "") + + distributed = read_distributed_json() + if distributed: + env_vars["SM_DISTRIBUTED_DRIVER_DIR"] = SM_DISTRIBUTED_DRIVER_DIR_PATH + env_vars["SM_DISTRIBUTED_CONFIG"] = distributed + + # Data Channels + channels = list(input_data_config.keys()) + for channel in channels: + env_vars[f"SM_CHANNEL_{channel.upper()}"] = f"{SM_INPUT_DATA_DIR}/{channel}" + env_vars["SM_CHANNELS"] = channels + + # Hyperparameters + hps = deserialize_hyperparameters(hyperparameters_config) + for key, value in hps.items(): + key_upper = key.replace("-", "_").upper() + env_vars[f"SM_HP_{key_upper}"] = value + env_vars["SM_HPS"] = hps + + # Host Variables + current_host = resource_config["current_host"] + current_instance_type = resource_config["current_instance_type"] + hosts = resource_config["hosts"] + sorted_hosts = sorted(hosts) + + env_vars["SM_CURRENT_HOST"] = current_host + env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type + env_vars["SM_HOSTS"] = sorted_hosts + env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] + env_vars["SM_HOST_COUNT"] = len(sorted_hosts) + env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) + + env_vars["SM_NUM_CPUS"] = num_cpus() + env_vars["SM_NUM_GPUS"] = num_gpus() + env_vars["SM_NUM_NEURONS"] = num_neurons() + + # Misc. 
+ env_vars["SM_RESOURCE_CONFIG"] = resource_config + env_vars["SM_INPUT_DATA_CONFIG"] = input_data_config + + # All Training Environment Variables + env_vars["SM_TRAINING_ENV"] = { + "channel_input_dirs": { + channel: env_vars[f"SM_CHANNEL_{channel.upper()}"] for channel in channels + }, + "current_host": env_vars["SM_CURRENT_HOST"], + "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], + "hosts": env_vars["SM_HOSTS"], + "master_addr": env_vars["SM_MASTER_ADDR"], + "master_port": env_vars["SM_MASTER_PORT"], + "hyperparameters": env_vars["SM_HPS"], + "input_data_config": input_data_config, + "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], + "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], + "input_dir": env_vars["SM_INPUT_DIR"], + "job_name": os.environ["TRAINING_JOB_NAME"], + "log_level": env_vars["SM_LOG_LEVEL"], + "model_dir": env_vars["SM_MODEL_DIR"], + "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], + "num_cpus": env_vars["SM_NUM_CPUS"], + "num_gpus": env_vars["SM_NUM_GPUS"], + "num_neurons": env_vars["SM_NUM_NEURONS"], + "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], + "resource_config": env_vars["SM_RESOURCE_CONFIG"], + } + with open(output_file, "w") as f: + for key, value in env_vars.items(): + f.write(f"export {key}='{safe_serialize(value)}'\n") + + logger.info("Environment Variables:") + log_env_variables(env_vars_dict=env_vars) + + +def mask_sensitive_info(data): + """Recursively mask sensitive information in a dictionary.""" + if isinstance(data, dict): + for k, v in data.items(): + if isinstance(v, dict): + data[k] = mask_sensitive_info(v) + elif isinstance(v, str) and any( + keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS + ): + data[k] = HIDDEN_VALUE + return data + + +def log_key_value(key: str, value: str): + """Log a key-value pair, masking sensitive values if necessary.""" + if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): + logger.info("%s=%s", key, HIDDEN_VALUE) + elif isinstance(value, dict): + masked_value = mask_sensitive_info(value) + logger.info("%s=%s", key, json.dumps(masked_value)) + else: + try: + decoded_value = json.loads(value) + if isinstance(decoded_value, dict): + masked_value = mask_sensitive_info(decoded_value) + logger.info("%s=%s", key, json.dumps(masked_value)) + else: + logger.info("%s=%s", key, decoded_value) + except (json.JSONDecodeError, TypeError): + logger.info("%s=%s", key, value) + + +def log_env_variables(env_vars_dict: Dict[str, Any]): + """Log Environment Variables from the environment and an env_vars_dict.""" + for key, value in os.environ.items(): + log_key_value(key, value) + + for key, value in env_vars_dict.items(): + log_key_value(key, value) + + +def main(): + """Main function to set the environment variables for the training job container.""" + with open(RESOURCE_CONFIG, "r") as f: + resource_config = json.load(f) + with open(INPUT_DATA_CONFIG, "r") as f: + input_data_config = json.load(f) + with open(HYPERPARAMETERS_CONFIG, "r") as f: + hyperparameters_config = json.load(f) + + set_env( + resource_config=resource_config, + input_data_config=input_data_config, + hyperparameters_config=hyperparameters_config, + output_file=ENV_OUTPUT_FILE, + ) + + +if __name__ == "__main__": + main() diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py new file mode 100644 index 0000000000..828c5da198 --- /dev/null +++ b/src/sagemaker/modules/train/model_trainer.py @@ -0,0 +1,1437 @@ +# Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""ModelTrainer class module.""" +from __future__ import absolute_import + +from enum import Enum +import os +import json +import shutil +from tempfile import TemporaryDirectory +from typing import Optional, List, Union, Dict, Any, ClassVar +import yaml + +from graphene.utils.str_converters import to_camel_case, to_snake_case + +from sagemaker_core.main import resources +from sagemaker_core.resources import TrainingJob +from sagemaker_core import shapes +from sagemaker_core.shapes import AlgorithmSpecification +from sagemaker_core.main.utils import serialize + +from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call + +from sagemaker.config.config_schema import ( + _simple_path, + SAGEMAKER, + MODEL_TRAINER, + MODULES, + PYTHON_SDK, + TRAINING_JOB_ENVIRONMENT_PATH, + TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + TRAINING_JOB_VPC_CONFIG_PATH, + TRAINING_JOB_SUBNETS_PATH, + TRAINING_JOB_SECURITY_GROUP_IDS_PATH, + TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH, + TRAINING_JOB_RESOURCE_CONFIG_PATH, + TRAINING_JOB_ROLE_ARN_PATH, + TRAINING_JOB_TAGS_PATH, +) + +from sagemaker.utils import resolve_value_from_config +from sagemaker.modules import Session, get_execution_role +from sagemaker.modules import configs +from sagemaker.modules.configs import ( + Compute, + StoppingCondition, + RetryStrategy, + SourceCode, + TrainingImageConfig, + Channel, + DataSource, + S3DataSource, + FileSystemDataSource, + Networking, + Tag, + InfraCheckConfig, + RemoteDebugConfig, + SessionChainingConfig, + InputData, + MetricDefinition, +) + +from sagemaker.modules.local_core.local_container import _LocalContainer +from sagemaker.modules.distributed import Torchrun, DistributedConfig +from sagemaker.modules.utils import ( + _get_repo_name_from_image, + _get_unique_name, + _is_valid_path, + _is_valid_s3_uri, + safe_serialize, +) +from sagemaker.modules.types import DataSourceType +from sagemaker.modules.constants import ( + DEFAULT_INSTANCE_TYPE, + SM_CODE, + SM_CODE_CONTAINER_PATH, + SM_DRIVERS, + SM_DRIVERS_LOCAL_PATH, + SM_RECIPE, + SM_RECIPE_YAML, + SM_RECIPE_CONTAINER_PATH, + TRAIN_SCRIPT, + DEFAULT_CONTAINER_ENTRYPOINT, + DEFAULT_CONTAINER_ARGUMENTS, + SOURCE_CODE_JSON, + DISTRIBUTED_JSON, +) +from sagemaker.modules.templates import ( + TRAIN_SCRIPT_TEMPLATE, + EXECUTE_BASE_COMMANDS, + EXEUCTE_DISTRIBUTED_DRIVER, + EXECUTE_BASIC_SCRIPT_DRIVER, +) +from sagemaker.telemetry.telemetry_logging import _telemetry_emitter +from sagemaker.telemetry.constants import Feature +from sagemaker.modules import logger +from sagemaker.modules.train.sm_recipes.utils import ( + _get_args_from_recipe, + _determine_device_type, + _is_nova_recipe, + _load_base_recipe, +) + + +class Mode(Enum): + """Enum class for training mode.""" + + LOCAL_CONTAINER = "LOCAL_CONTAINER" + SAGEMAKER_TRAINING_JOB = "SAGEMAKER_TRAINING_JOB" + + +class ModelTrainer(BaseModel): + """Class that trains a model using AWS SageMaker. + + Example: + + .. 
code:: python
+
+        from sagemaker.modules.train import ModelTrainer
+        from sagemaker.modules.configs import SourceCode, Compute, InputData
+
+        ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
+        source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
+        training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
+        model_trainer = ModelTrainer(
+            training_image=training_image,
+            source_code=source_code,
+        )
+
+        train_data = InputData(channel_name="train", data_source="s3://bucket/train")
+        model_trainer.train(input_data_config=[train_data])
+
+        training_job = model_trainer._latest_training_job
+
+    Parameters:
+        training_mode (Mode):
+            The training mode. Valid values are "Mode.LOCAL_CONTAINER" or
+            "Mode.SAGEMAKER_TRAINING_JOB".
+        sagemaker_session (Optional[Session]):
+            The SageMakerCore session. For convenience, can be imported like:
+            ``from sagemaker.modules import Session``.
+            If not specified, a new session will be created.
+            If the default bucket for the artifacts needs to be updated, it can be done by
+            passing it in the Session object.
+        role (Optional[str]):
+            The IAM role ARN for the training job.
+            If not specified, the default SageMaker execution role will be used.
+        base_job_name (Optional[str]):
+            The base name for the training job.
+            If not specified, a default name will be generated using the algorithm name
+            or training image.
+        source_code (Optional[SourceCode]):
+            The source code configuration. This is used to configure the source code for
+            running the training job.
+        distributed (Optional[DistributedConfig]):
+            The distributed runner for the training job. This is used to configure
+            a distributed training job. If specified, ``source_code`` must also
+            be provided.
+        compute (Optional[Compute]):
+            The compute configuration. This is used to specify the compute resources for
+            the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
+        networking (Optional[Networking]):
+            The networking configuration. This is used to specify the networking settings
+            for the training job.
+        stopping_condition (Optional[StoppingCondition]):
+            The stopping condition. This is used to specify the different stopping
+            conditions for the training job.
+            If not specified, will default to 1 hour max run time.
+        algorithm_name (Optional[str]):
+            The SageMaker marketplace algorithm name/arn to use for the training job.
+            algorithm_name cannot be specified if training_image is specified.
+        training_image (Optional[str]):
+            The training image URI to use for the training job container.
+            training_image cannot be specified if algorithm_name is specified.
+            To find available sagemaker distributed images,
+            see: https://docs.aws.amazon.com/sagemaker/latest/dg-ecr-paths/sagemaker-algo-docker-registry-paths
+        training_image_config (Optional[TrainingImageConfig]):
+            Training image Config. This is the configuration to use an image from a private
+            Docker registry for a training job.
+        output_data_config (Optional[OutputDataConfig]):
+            The output data configuration. This is used to specify the output data location
+            for the training job.
+            If not specified in the session, will default to
+            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/``.
+        input_data_config (Optional[List[Union[Channel, InputData]]]):
+            The input data config for the training job.
+            Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
+            string, local file path string, S3DataSource object, or FileSystemDataSource object.
+        checkpoint_config (Optional[CheckpointConfig]):
+            Contains information about the output location for managed spot training checkpoint
+            data.
+        training_input_mode (Optional[str]):
+            The input mode for the training job. Valid values are "Pipe", "File", "FastFile".
+            Defaults to "File".
+        environment (Optional[Dict[str, str]]):
+            The environment variables for the training job.
+        hyperparameters (Optional[Union[Dict[str, Any], str]]):
+            The hyperparameters for the training job. Can be a dictionary of hyperparameters
+            or a path to a hyperparameters json/yaml file.
+        tags (Optional[List[Tag]]):
+            An array of key-value pairs. You can use tags to categorize your AWS resources
+            in different ways, for example, by purpose, owner, or environment.
+        local_container_root (Optional[str]):
+            The local root directory to store artifacts from a training job launched in
+            "LOCAL_CONTAINER" mode.
+    """
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True, validate_assignment=True, extra="forbid"
+    )
+
+    training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB
+    sagemaker_session: Optional[Session] = None
+    role: Optional[str] = None
+    base_job_name: Optional[str] = None
+    source_code: Optional[SourceCode] = None
+    distributed: Optional[DistributedConfig] = None
+    compute: Optional[Compute] = None
+    networking: Optional[Networking] = None
+    stopping_condition: Optional[StoppingCondition] = None
+    training_image: Optional[str] = None
+    training_image_config: Optional[TrainingImageConfig] = None
+    algorithm_name: Optional[str] = None
+    output_data_config: Optional[shapes.OutputDataConfig] = None
+    input_data_config: Optional[List[Union[Channel, InputData]]] = None
+    checkpoint_config: Optional[shapes.CheckpointConfig] = None
+    training_input_mode: Optional[str] = "File"
+    environment: Optional[Dict[str, str]] = {}
+    hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
+    tags: Optional[List[Tag]] = None
+    local_container_root: Optional[str] = os.getcwd()
+
+    # Created Artifacts
+    _latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None)
+
+    # Private TrainingJob Parameters
+    _tensorboard_output_config: Optional[shapes.TensorBoardOutputConfig] = PrivateAttr(default=None)
+    _retry_strategy: Optional[RetryStrategy] = PrivateAttr(default=None)
+    _infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None)
+    _session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None)
+    _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
+    _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
+
+    _is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
+    _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
+    _temp_code_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
+
+    CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [
+        "role",
+        "base_job_name",
+        "source_code",
+        "compute",
+        "networking",
+        "stopping_condition",
+        "training_image",
+        "training_image_config",
+        "algorithm_name",
+        "output_data_config",
+        "checkpoint_config",
+        "training_input_mode",
+        "environment",
+        "hyperparameters",
+    ]
+
+    SERIALIZABLE_CONFIG_ATTRIBUTES: ClassVar[Any] = {
+        "source_code": SourceCode,
+        "compute": Compute,
+        "networking": Networking,
+        "stopping_condition": StoppingCondition,
+        "training_image_config": TrainingImageConfig,
+        "output_data_config": configs.OutputDataConfig,
+        "checkpoint_config": configs.CheckpointConfig,
+    }
+
+    def _populate_intelligent_defaults(self):
"""Function to populate all the possible default configs + + Model Trainer specific configs take precedence over the generic training job ones. + """ + self._populate_intelligent_defaults_from_model_trainer_space() + self._populate_intelligent_defaults_from_training_job_space() + + def _populate_intelligent_defaults_from_training_job_space(self): + """Function to populate all the possible default configs from Training Job Space""" + if not self.environment: + self.environment = resolve_value_from_config( + config_path=TRAINING_JOB_ENVIRONMENT_PATH, sagemaker_session=self.sagemaker_session + ) + + default_enable_network_isolation = resolve_value_from_config( + config_path=TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + sagemaker_session=self.sagemaker_session, + ) + default_vpc_config = resolve_value_from_config( + config_path=TRAINING_JOB_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session + ) + + if not self.networking: + if default_enable_network_isolation is not None or default_vpc_config is not None: + self.networking = Networking( + default_enable_network_isolation=default_enable_network_isolation, + subnets=resolve_value_from_config(config_path=TRAINING_JOB_SUBNETS_PATH), + security_group_ids=resolve_value_from_config( + config_path=TRAINING_JOB_SECURITY_GROUP_IDS_PATH + ), + ) + else: + if self.networking.enable_network_isolation is None: + self.networking.enable_network_isolation = default_enable_network_isolation + if self.networking.subnets is None: + self.networking.subnets = resolve_value_from_config( + config_path=TRAINING_JOB_SUBNETS_PATH + ) + if self.networking.security_group_ids is None: + self.networking.subnets = resolve_value_from_config( + config_path=TRAINING_JOB_SUBNETS_PATH + ) + + if not self.output_data_config: + default_output_data_config = resolve_value_from_config( + config_path=TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH + ) + if default_output_data_config: + self.output_data_config = configs.OutputDataConfig( + **self._convert_keys_to_snake(default_output_data_config) + ) + + if not self.compute: + default_resource_config = resolve_value_from_config( + config_path=TRAINING_JOB_RESOURCE_CONFIG_PATH + ) + if default_resource_config: + self.compute = Compute(**self._convert_keys_to_snake(default_resource_config)) + + if not self.role: + self.role = resolve_value_from_config(config_path=TRAINING_JOB_ROLE_ARN_PATH) + + if not self.tags: + self.tags = resolve_value_from_config(config_path=TRAINING_JOB_TAGS_PATH) + + def _convert_keys_to_snake(self, config: dict) -> dict: + """Utility helper function that converts the keys of a dictionary into snake case""" + return {to_snake_case(key): value for key, value in config.items()} + + def _populate_intelligent_defaults_from_model_trainer_space(self): + """Function to populate all the possible default configs from Model Trainer Space""" + + for configurable_attribute in self.CONFIGURABLE_ATTRIBUTES: + if getattr(self, configurable_attribute) is None: + default_config = resolve_value_from_config( + config_path=_simple_path( + SAGEMAKER, + PYTHON_SDK, + MODULES, + MODEL_TRAINER, + to_camel_case(configurable_attribute), + ), + sagemaker_session=self.sagemaker_session, + ) + if default_config is not None: + if configurable_attribute in self.SERIALIZABLE_CONFIG_ATTRIBUTES: + default_config = self.SERIALIZABLE_CONFIG_ATTRIBUTES.get( + configurable_attribute + )( + **default_config # pylint: disable=E1134 + ) + setattr(self, configurable_attribute, default_config) + + def __del__(self): + """Destructor method to clean up the 
+        # Clean up the temporary directory if it exists and class was initialized
+        if hasattr(self, "__pydantic_fields_set__"):
+            if self._temp_recipe_train_dir is not None:
+                self._temp_recipe_train_dir.cleanup()
+            if self._temp_code_dir is not None:
+                self._temp_code_dir.cleanup()
+
+    def _validate_training_image_and_algorithm_name(
+        self, training_image: Optional[str], algorithm_name: Optional[str]
+    ):
+        """Validate that only one of 'training_image' or 'algorithm_name' is provided."""
+        if not training_image and not algorithm_name:
+            raise ValueError(
+                "At least one of 'training_image' or 'algorithm_name' must be provided.",
+            )
+        if training_image and algorithm_name:
+            raise ValueError(
+                "Only one of 'training_image' or 'algorithm_name' must be provided.",
+            )
+
+    def _validate_distributed_config(
+        self,
+        source_code: Optional[SourceCode],
+        distributed: Optional[DistributedConfig],
+    ):
+        """Validate the distributed configuration."""
+        if distributed and not source_code.entry_script:
+            raise ValueError(
+                "Must provide 'entry_script' in 'source_code' if 'distributed' is provided.",
+            )
+
+    # TODO: Move to use pydantic model validators
+    def _validate_source_code(self, source_code: Optional[SourceCode]):
+        """Validate the source code configuration."""
+        if source_code:
+            if source_code.requirements or source_code.entry_script:
+                source_dir = source_code.source_dir
+                requirements = source_code.requirements
+                entry_script = source_code.entry_script
+                if not source_dir:
+                    raise ValueError(
+                        "If 'requirements' or 'entry_script' is provided in 'source_code', "
+                        + "'source_dir' must also be provided.",
+                    )
+                if not (
+                    _is_valid_path(source_dir, path_type="Directory")
+                    or _is_valid_s3_uri(source_dir, path_type="Directory")
+                    or (
+                        _is_valid_path(source_dir, path_type="File")
+                        and source_dir.endswith(".tar.gz")
+                    )
+                    or (
+                        _is_valid_s3_uri(source_dir, path_type="File")
+                        and source_dir.endswith(".tar.gz")
+                    )
+                ):
+                    raise ValueError(
+                        f"Invalid 'source_dir' path: {source_dir}. "
+                        + "Must be a valid local directory, "
+                        "S3 URI, or path to a tar.gz file stored locally or in S3.",
+                    )
+                if requirements:
+                    if not source_dir.endswith(".tar.gz"):
+                        if not _is_valid_path(
+                            f"{source_dir}/{requirements}", path_type="File"
+                        ) and not _is_valid_s3_uri(
+                            f"{source_dir}/{requirements}", path_type="File"
+                        ):
+                            raise ValueError(
+                                f"Invalid 'requirements': {requirements}. "
+                                + "Must be a valid file within the 'source_dir'.",
+                            )
+                if entry_script:
+                    if not source_dir.endswith(".tar.gz"):
+                        if not _is_valid_path(
+                            f"{source_dir}/{entry_script}", path_type="File"
+                        ) and not _is_valid_s3_uri(
+                            f"{source_dir}/{entry_script}", path_type="File"
+                        ):
+                            raise ValueError(
+                                f"Invalid 'entry_script': {entry_script}. "
+                                + "Must be a valid file within the 'source_dir'.",
+                            )
" + + "Must be a valid file within the 'source_dir'.", + ) + + @staticmethod + def _validate_and_load_hyperparameters_file(hyperparameters_file: str) -> Dict[str, Any]: + """Validate the hyperparameters file.""" + if not os.path.exists(hyperparameters_file): + raise ValueError(f"Hyperparameters file not found: {hyperparameters_file}") + logger.info(f"Loading hyperparameters from file: {hyperparameters_file}") + with open(hyperparameters_file, "r") as f: + contents = f.read() + try: + hyperparameters = json.loads(contents) + logger.debug("Hyperparameters loaded as JSON") + return hyperparameters + except json.JSONDecodeError: + try: + logger.info(f"contents: {contents}") + hyperparameters = yaml.safe_load(contents) + if not isinstance(hyperparameters, dict): + raise ValueError("YAML contents must be a valid mapping") + logger.info(f"hyperparameters: {hyperparameters}") + logger.debug("Hyperparameters loaded as YAML") + return hyperparameters + except (yaml.YAMLError, ValueError): + raise ValueError( + f"Invalid hyperparameters file: {hyperparameters_file}. " + "Must be a valid JSON or YAML file." + ) + + def model_post_init(self, __context: Any): + """Post init method to perform custom validation and set default values.""" + self._validate_training_image_and_algorithm_name(self.training_image, self.algorithm_name) + self._validate_source_code(self.source_code) + self._validate_distributed_config(self.source_code, self.distributed) + + if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: + if self.sagemaker_session is None: + self.sagemaker_session = Session() + logger.warning("SageMaker session not provided. Using default Session.") + + if self.role is None: + self.role = get_execution_role(sagemaker_session=self.sagemaker_session) + logger.warning(f"Role not provided. Using default role:\n{self.role}") + + if self.base_job_name is None: + if self.algorithm_name: + self.base_job_name = f"{self.algorithm_name}-job" + elif self.training_image: + self.base_job_name = f"{_get_repo_name_from_image(self.training_image)}-job" + logger.warning(f"Base name not provided. Using default name:\n{self.base_job_name}") + + if self.compute is None: + self.compute = Compute( + instance_type=DEFAULT_INSTANCE_TYPE, + instance_count=1, + volume_size_in_gb=30, + ) + logger.warning(f"Compute not provided. Using default:\n{self.compute}") + + if self.compute.instance_type is None: + self.compute.instance_type = DEFAULT_INSTANCE_TYPE + logger.warning(f"Instance type not provided. Using default:\n{DEFAULT_INSTANCE_TYPE}") + if self.compute.instance_count is None: + self.compute.instance_count = 1 + logger.warning( + f"Instance count not provided. Using default:\n{self.compute.instance_count}" + ) + if self.compute.volume_size_in_gb is None: + self.compute.volume_size_in_gb = 30 + logger.warning( + f"Volume size not provided. Using default:\n{self.compute.volume_size_in_gb}" + ) + + if self.stopping_condition is None: + self.stopping_condition = StoppingCondition( + max_runtime_in_seconds=3600, + max_pending_time_in_seconds=None, + max_wait_time_in_seconds=None, + ) + logger.warning( + f"StoppingCondition not provided. Using default:\n{self.stopping_condition}" + ) + if self.stopping_condition.max_runtime_in_seconds is None: + self.stopping_condition.max_runtime_in_seconds = 3600 + logger.info( + "Max runtime not provided. 
+
+        if self.hyperparameters and isinstance(self.hyperparameters, str):
+            self.hyperparameters = self._validate_and_load_hyperparameters_file(
+                self.hyperparameters
+            )
+
+        if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
+            if self.output_data_config is None:
+                session = self.sagemaker_session
+                base_job_name = self.base_job_name
+                self.output_data_config = configs.OutputDataConfig(
+                    s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}"
+                    f"/{base_job_name}",
+                    compression_type="GZIP",
+                    kms_key_id=None,
+                )
+                logger.warning(
+                    f"OutputDataConfig not provided. Using default:\n{self.output_data_config}"
+                )
+            if self.output_data_config.s3_output_path is None:
+                session = self.sagemaker_session
+                base_job_name = self.base_job_name
+                self.output_data_config.s3_output_path = (
+                    f"s3://{self._fetch_bucket_name_and_prefix(session)}/{base_job_name}"
+                )
+                logger.warning(
+                    f"OutputDataConfig s3_output_path not provided. Using default:\n"
+                    f"{self.output_data_config.s3_output_path}"
+                )
+            if self.output_data_config.compression_type is None:
+                self.output_data_config.compression_type = "GZIP"
+                logger.warning(
+                    f"OutputDataConfig compression type not provided. Using default:\n"
+                    f"{self.output_data_config.compression_type}"
+                )
+
+        if self.training_image:
+            logger.info(f"Training image URI: {self.training_image}")
+
+    @staticmethod
+    def _fetch_bucket_name_and_prefix(session: Session) -> str:
+        """Helper function to get the bucket name with the corresponding prefix if applicable"""
+        if session.default_bucket_prefix is not None:
+            return f"{session.default_bucket()}/{session.default_bucket_prefix}"
+        return session.default_bucket()
+
+    def _create_training_job_args(
+        self,
+        input_data_config: Optional[List[Union[Channel, InputData]]] = None,
+        boto3: bool = False,
+    ) -> Dict[str, Any]:
+        """Create the training job arguments.
+
+        Args:
+            input_data_config (Optional[List[Union[Channel, InputData]]]):
+                The input data config for the training job.
+                Takes a list of Channel objects or a dictionary of channel names to DataSourceType.
+                DataSourceType can be an S3 URI string, local file path string,
+                S3DataSource object, or FileSystemDataSource object.
+            boto3 (bool): Whether to return the arguments in boto3 format. Defaults to False.
+                By default, the arguments are returned in the format used by the SageMaker Core.
+
+        Returns:
+            Dict[str, Any]: The training job arguments.
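+
+            In the default (SageMaker Core) format, the returned dictionary is
+            expected to contain keys such as ``training_job_name``,
+            ``algorithm_specification``, ``hyper_parameters``, ``input_data_config``,
+            ``resource_config``, and ``stopping_condition``; with ``boto3=True``,
+            the same information is returned under CamelCase keys
+            (e.g. ``TrainingJobName``).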
+ """ + self._populate_intelligent_defaults() + current_training_job_name = _get_unique_name(self.base_job_name) + input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input" + + final_input_data_config = self.input_data_config.copy() if self.input_data_config else [] + + if input_data_config: + # merge the inputs with method parameter taking precedence + existing_channels = {input.channel_name: input for input in final_input_data_config} + new_channels = [] + for new_input in input_data_config: + if new_input.channel_name in existing_channels: + existing_channels[new_input.channel_name] = new_input + else: + new_channels.append(new_input) + + final_input_data_config = list(existing_channels.values()) + new_channels + + if self._is_nova_recipe: + for input_data in final_input_data_config: + if input_data.channel_name == SM_RECIPE: + raise ValueError( + "Cannot use reserved channel name 'recipe' as an input channel name " + " for Nova Recipe" + ) + recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML) + recipe_channel = self.create_input_data_channel( + channel_name=SM_RECIPE, + data_source=recipe_file_path, + key_prefix=input_data_key_prefix, + ) + final_input_data_config.append(recipe_channel) + self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH}) + + if final_input_data_config: + final_input_data_config = self._get_input_data_config( + final_input_data_config, input_data_key_prefix + ) + + if self.checkpoint_config and not self.checkpoint_config.s3_uri: + self.checkpoint_config.s3_uri = ( + f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/" + f"{self.base_job_name}/{current_training_job_name}/checkpoints" + ) + if self._tensorboard_output_config and not self._tensorboard_output_config.s3_output_path: + self._tensorboard_output_config.s3_output_path = ( + f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/" + f"{self.base_job_name}" + ) + + string_hyper_parameters = {} + if self.hyperparameters: + for hyper_parameter, value in self.hyperparameters.items(): + string_hyper_parameters[hyper_parameter] = safe_serialize(value) + + container_entrypoint = None + container_arguments = None + if self.source_code: + if self.training_mode == Mode.LOCAL_CONTAINER: + self._temp_code_dir = TemporaryDirectory( + prefix=os.path.join(self.local_container_root + "/") + ) + else: + self._temp_code_dir = TemporaryDirectory() + # Copy everything under container_drivers/ to a temporary directory + shutil.copytree(SM_DRIVERS_LOCAL_PATH, self._temp_code_dir.name, dirs_exist_ok=True) + + # If distributed is provided, overwrite code under /drivers + if self.distributed: + distributed_driver_dir = self.distributed.driver_dir + driver_dir = os.path.join(self._temp_code_dir.name, "distributed_drivers") + shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True) + + # If source code is provided, create a channel for the source code + # The source code will be mounted at /opt/ml/input/data/code in the container + if self.source_code.source_dir: + source_code_channel = self.create_input_data_channel( + channel_name=SM_CODE, + data_source=self.source_code.source_dir, + key_prefix=input_data_key_prefix, + ignore_patterns=self.source_code.ignore_patterns, + ) + final_input_data_config.append(source_code_channel) + + self._prepare_train_script( + tmp_dir=self._temp_code_dir, + source_code=self.source_code, + distributed=self.distributed, + ) + + if isinstance(self.distributed, Torchrun) and 
self.distributed.smp:
+                mp_parameters = self.distributed.smp._to_mp_hyperparameters()
+                string_hyper_parameters.update(mp_parameters)
+
+            self._write_source_code_json(tmp_dir=self._temp_code_dir, source_code=self.source_code)
+            self._write_distributed_json(tmp_dir=self._temp_code_dir, distributed=self.distributed)
+
+            # Create an input channel for drivers packaged by the sdk
+            sm_drivers_channel = self.create_input_data_channel(
+                channel_name=SM_DRIVERS,
+                data_source=self._temp_code_dir.name,
+                key_prefix=input_data_key_prefix,
+                ignore_patterns=self.source_code.ignore_patterns,
+            )
+            final_input_data_config.append(sm_drivers_channel)
+
+            # If source_code is provided, we will always use
+            # the default container entrypoint and arguments
+            # to execute the sm_train.sh script.
+            # Any commands generated from the source_code will be
+            # executed from the sm_train.sh script.
+            container_entrypoint = DEFAULT_CONTAINER_ENTRYPOINT
+            container_arguments = DEFAULT_CONTAINER_ARGUMENTS
+
+        algorithm_specification = AlgorithmSpecification(
+            algorithm_name=self.algorithm_name,
+            training_image=self.training_image,
+            training_input_mode=self.training_input_mode,
+            training_image_config=self.training_image_config,
+            container_entrypoint=container_entrypoint,
+            container_arguments=container_arguments,
+            metric_definitions=self._metric_definitions,
+        )
+
+        resource_config = self.compute._to_resource_config()
+        vpc_config = self.networking._to_vpc_config() if self.networking else None
+
+        if boto3:
+            args = {}
+            args["TrainingJobName"] = current_training_job_name
+            args["AlgorithmSpecification"] = algorithm_specification
+            args["HyperParameters"] = string_hyper_parameters
+            args["InputDataConfig"] = final_input_data_config
+            args["ResourceConfig"] = resource_config
+            args["VpcConfig"] = vpc_config
+            args["RoleArn"] = self.role
+            args["Tags"] = self.tags
+            args["StoppingCondition"] = self.stopping_condition
+            args["OutputDataConfig"] = self.output_data_config
+            args["CheckpointConfig"] = self.checkpoint_config
+            args["Environment"] = self.environment
+            args["EnableManagedSpotTraining"] = self.compute.enable_managed_spot_training
+            args["EnableInterContainerTrafficEncryption"] = (
+                self.networking.enable_inter_container_traffic_encryption
+                if self.networking
+                else None
+            )
+            args["EnableNetworkIsolation"] = (
+                self.networking.enable_network_isolation if self.networking else None
+            )
+            args["RemoteDebugConfig"] = self._remote_debug_config
+            args["TensorBoardOutputConfig"] = self._tensorboard_output_config
+            args["RetryStrategy"] = self._retry_strategy
+            args["InfraCheckConfig"] = self._infra_check_config
+            args["SessionChainingConfig"] = self._session_chaining_config
+            return serialize(args)
+        else:
+            args = {}
+            args["training_job_name"] = current_training_job_name
+            args["algorithm_specification"] = algorithm_specification
+            args["hyper_parameters"] = string_hyper_parameters
+            args["input_data_config"] = final_input_data_config
+            args["resource_config"] = resource_config
+            args["vpc_config"] = vpc_config
+            args["session"] = self.sagemaker_session.boto_session
+            args["role_arn"] = self.role
+            args["tags"] = self.tags
+            args["stopping_condition"] = self.stopping_condition
+            args["output_data_config"] = self.output_data_config
+            args["checkpoint_config"] = self.checkpoint_config
+            args["environment"] = self.environment
+            args["enable_managed_spot_training"] = self.compute.enable_managed_spot_training
+            args["enable_inter_container_traffic_encryption"] = (
+                self.networking.enable_inter_container_traffic_encryption
+                if self.networking
+                else None
+            )
+            args["enable_network_isolation"] = (
+                self.networking.enable_network_isolation if self.networking else None
+            )
+            args["remote_debug_config"] = self._remote_debug_config
+            args["tensor_board_output_config"] = self._tensorboard_output_config
+            args["retry_strategy"] = self._retry_strategy
+            args["infra_check_config"] = self._infra_check_config
+            args["session_chaining_config"] = self._session_chaining_config
+            return args
+
+    @_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train")
+    @validate_call
+    def train(
+        self,
+        input_data_config: Optional[List[Union[Channel, InputData]]] = None,
+        wait: Optional[bool] = True,
+        logs: Optional[bool] = True,
+    ):
+        """Train a model using AWS SageMaker.
+
+        Args:
+            input_data_config (Optional[List[Union[Channel, InputData]]]):
+                The input data config for the training job.
+                Takes a list of Channel objects or a dictionary of channel names to DataSourceType.
+                DataSourceType can be an S3 URI string, local file path string,
+                S3DataSource object, or FileSystemDataSource object.
+            wait (Optional[bool]):
+                Whether to wait for the training job to complete before returning.
+                Defaults to True.
+            logs (Optional[bool]):
+                Whether to display the training container logs while training.
+                Defaults to True.
+        """
+        args = self._create_training_job_args(input_data_config=input_data_config)
+        if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
+            training_job = TrainingJob.create(**args)
+            self._latest_training_job = training_job
+            if wait:
+                training_job.wait(logs=logs)
+            if logs and not wait:
+                logger.warning(
+                    "Not displaying the training container logs as 'wait' is set to False."
+                )
+        else:
+            local_container = _LocalContainer(
+                training_job_name=args["training_job_name"],
+                instance_type=args["resource_config"].instance_type,
+                instance_count=args["resource_config"].instance_count,
+                image=args["algorithm_specification"].training_image,
+                container_root=self.local_container_root,
+                sagemaker_session=self.sagemaker_session,
+                container_entrypoint=args["algorithm_specification"].container_entrypoint,
+                container_arguments=args["algorithm_specification"].container_arguments,
+                input_data_config=args["input_data_config"],
+                hyper_parameters=args["hyper_parameters"],
+                environment=args["environment"],
+            )
+            local_container.train(wait)
+            if self._temp_code_dir is not None:
+                self._temp_code_dir.cleanup()
+
+    def create_input_data_channel(
+        self,
+        channel_name: str,
+        data_source: DataSourceType,
+        key_prefix: Optional[str] = None,
+        ignore_patterns: Optional[List[str]] = None,
+    ) -> Channel:
+        """Create an input data channel for the training job.
+
+        Args:
+            channel_name (str): The name of the input data channel.
+            data_source (DataSourceType): The data source for the input data channel.
+                DataSourceType can be an S3 URI string, local file path string,
+                S3DataSource object, or FileSystemDataSource object.
+            key_prefix (Optional[str]): The key prefix to use when uploading data to S3.
+                Only applicable when data_source is a local file path string.
+                If not specified, local data will be uploaded to:
+                ``s3://<default_bucket>/<base_job_name>/input/<channel_name>/``
+
+                If specified, local data will be uploaded to:
+                ``s3://<default_bucket>/<key_prefix>/<channel_name>/``
+            ignore_patterns (Optional[List[str]]):
+                The ignore patterns to ignore specific files/folders when uploading to S3.
+                If not specified, default to: ['.env', '.git', '__pycache__', '.DS_Store',
+                '.cache', '.ipynb_checkpoints'].
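+
+        Returns:
+            Channel: The configured input data channel.
+
+        Example:
+
+            A minimal sketch; the bucket, channel names, and local path below are
+            illustrative assumptions.
+
+            .. code:: python
+
+                # Reference data already in S3
+                train_channel = model_trainer.create_input_data_channel(
+                    channel_name="train",
+                    data_source="s3://my-bucket/datasets/train",
+                )
+
+                # Upload a local directory, skipping scratch files
+                val_channel = model_trainer.create_input_data_channel(
+                    channel_name="validation",
+                    data_source="./data/validation",
+                    key_prefix="my-experiment/input",
+                    ignore_patterns=[".DS_Store", "__pycache__"],
+                )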
+ """ + channel = None + if isinstance(data_source, str): + if _is_valid_s3_uri(data_source): + channel = Channel( + channel_name=channel_name, + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=data_source, + s3_data_distribution_type="FullyReplicated", + ), + ), + input_mode="File", + ) + if key_prefix: + logger.warning( + "key_prefix is only applicable when data_source is a local file path." + ) + elif _is_valid_path(data_source): + if self.training_mode == Mode.LOCAL_CONTAINER: + channel = Channel( + channel_name=channel_name, + data_source=DataSource( + file_system_data_source=FileSystemDataSource.model_construct( + directory_path=data_source, + file_system_type="EFS", + ), + ), + input_mode="File", + ) + else: + key_prefix = ( + f"{key_prefix}/{channel_name}" + if key_prefix + else f"{self.base_job_name}/input/{channel_name}" + ) + if self.sagemaker_session.default_bucket_prefix: + key_prefix = f"{self.sagemaker_session.default_bucket_prefix}/{key_prefix}" + if ignore_patterns and _is_valid_path(data_source, path_type="Directory"): + tmp_dir = TemporaryDirectory() + copied_path = os.path.join( + tmp_dir.name, os.path.basename(os.path.normpath(data_source)) + ) + shutil.copytree( + data_source, + copied_path, + dirs_exist_ok=True, + ignore=shutil.ignore_patterns(*ignore_patterns), + ) + s3_uri = self.sagemaker_session.upload_data( + path=copied_path, + bucket=self.sagemaker_session.default_bucket(), + key_prefix=key_prefix, + ) + else: + s3_uri = self.sagemaker_session.upload_data( + path=data_source, + bucket=self.sagemaker_session.default_bucket(), + key_prefix=key_prefix, + ) + channel = Channel( + channel_name=channel_name, + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=s3_uri, + s3_data_distribution_type="FullyReplicated", + ), + ), + input_mode="File", + ) + else: + raise ValueError(f"Not a valid S3 URI or local file path: {data_source}.") + elif isinstance(data_source, S3DataSource): + channel = Channel( + channel_name=channel_name, data_source=DataSource(s3_data_source=data_source) + ) + elif isinstance(data_source, FileSystemDataSource): + channel = Channel( + channel_name=channel_name, + data_source=DataSource(file_system_data_source=data_source), + ) + return channel + + def _get_input_data_config( + self, + input_data_channels: Optional[List[Union[Channel, InputData]]], + key_prefix: Optional[str] = None, + ) -> List[Channel]: + """Get the input data configuration for the training job. + + Args: + input_data_channels (Optional[List[Union[Channel, InputData]]]): + The input data config for the training job. + Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI + string, local file path string, S3DataSource object, or FileSystemDataSource object. + """ + if input_data_channels is None: + return [] + + channels = [] + for input_data in input_data_channels: + if isinstance(input_data, Channel): + channels.append(input_data) + elif isinstance(input_data, InputData): + channel = self.create_input_data_channel( + input_data.channel_name, + input_data.data_source, + key_prefix=key_prefix, + ) + channels.append(channel) + else: + raise ValueError( + f"Invalid input data channel: {input_data}. " + + "Must be a Channel or InputDataSource." 
+                )
+        return channels
+
+    def _write_source_code_json(self, tmp_dir: TemporaryDirectory, source_code: SourceCode):
+        """Write the source code configuration to a JSON file."""
+        file_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON)
+        with open(file_path, "w") as f:
+            dump = source_code.model_dump() if source_code else {}
+            f.write(json.dumps(dump))
+
+    def _write_distributed_json(
+        self,
+        tmp_dir: TemporaryDirectory,
+        distributed: Optional[DistributedConfig] = None,
+    ):
+        """Write the distributed runner configuration to a JSON file."""
+        file_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON)
+        with open(file_path, "w") as f:
+            dump = distributed.model_dump() if distributed else {}
+            f.write(json.dumps(dump))
+
+    def _prepare_train_script(
+        self,
+        tmp_dir: TemporaryDirectory,
+        source_code: SourceCode,
+        distributed: Optional[DistributedConfig] = None,
+    ):
+        """Prepare the training script to be executed in the training job container.
+
+        Args:
+            tmp_dir (TemporaryDirectory): The temporary directory to write the script to.
+            source_code (SourceCode): The source code configuration.
+            distributed (Optional[DistributedConfig]): The distributed runner configuration.
+        """
+
+        base_command = ""
+        if source_code.command:
+            if source_code.entry_script:
+                logger.warning(
+                    "Both 'command' and 'entry_script' are provided in the SourceCode. "
+                    + "Defaulting to 'command'."
+                )
+            base_command = source_code.command.split()
+            base_command = " ".join(base_command)
+
+        install_requirements = ""
+        if source_code.requirements:
+            install_requirements = (
+                "echo 'Installing requirements'\n"
+                + f"$SM_PIP_CMD install -r {source_code.requirements}"
+            )
+
+        working_dir = ""
+        if source_code.source_dir:
+            working_dir = f"cd {SM_CODE_CONTAINER_PATH} \n"
+            if source_code.source_dir.endswith(".tar.gz"):
+                tarfile_name = os.path.basename(source_code.source_dir)
+                working_dir += f"tar -xzf {tarfile_name} \n"
+
+        if base_command:
+            execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command)
+        elif distributed:
+            execute_driver = EXEUCTE_DISTRIBUTED_DRIVER.format(
+                driver_name=distributed.__class__.__name__,
+                driver_script=distributed.driver_script,
+            )
+        elif source_code.entry_script and not source_code.command and not distributed:
+            if not source_code.entry_script.endswith((".py", ".sh")):
+                raise ValueError(
+                    f"Unsupported entry script: {source_code.entry_script}. "
+                    + "Only .py and .sh scripts are supported."
+                )
+            execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER
+        else:
+            # This should never be reached, as the source_code should have been validated.
+            raise ValueError(
+                f"Unsupported SourceCode or DistributedConfig: {source_code}, {distributed}. "
+                + "Please provide a valid configuration with at least one of 'command'"
+                + " or 'entry_script'."
+            )
+
+        train_script = TRAIN_SCRIPT_TEMPLATE.format(
+            working_dir=working_dir,
+            install_requirements=install_requirements,
+            execute_driver=execute_driver,
+        )
+
+        with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w") as f:
+            f.write(train_script)
+
+    @classmethod
+    def from_recipe(
+        cls,
+        training_recipe: str,
+        compute: Compute,
+        recipe_overrides: Optional[Dict[str, Any]] = None,
+        networking: Optional[Networking] = None,
+        stopping_condition: Optional[StoppingCondition] = None,
+        requirements: Optional[str] = None,
+        training_image: Optional[str] = None,
+        training_image_config: Optional[TrainingImageConfig] = None,
+        output_data_config: Optional[shapes.OutputDataConfig] = None,
+        input_data_config: Optional[List[Union[Channel, InputData]]] = None,
+        checkpoint_config: Optional[shapes.CheckpointConfig] = None,
+        training_input_mode: Optional[str] = "File",
+        environment: Optional[Dict[str, str]] = None,
+        hyperparameters: Optional[Union[Dict[str, Any], str]] = {},
+        tags: Optional[List[Tag]] = None,
+        sagemaker_session: Optional[Session] = None,
+        role: Optional[str] = None,
+        base_job_name: Optional[str] = None,
+    ) -> "ModelTrainer":  # noqa: D412
+        """Create a ModelTrainer from a training recipe.
+
+        Example:
+
+        .. code:: python
+
+            from sagemaker.modules.train import ModelTrainer
+            from sagemaker.modules.configs import Compute
+
+            recipe_overrides = {
+                "run": {
+                    "results_dir": "/opt/ml/model",
+                },
+                "model": {
+                    "data": {
+                        "use_synthetic_data": True
+                    }
+                }
+            }
+
+            compute = Compute(
+                instance_type="ml.p5.48xlarge",
+                keep_alive_period_in_seconds=3600
+            )
+
+            model_trainer = ModelTrainer.from_recipe(
+                training_recipe="fine-tuning/deepseek/hf_deepseek_r1_distilled_llama_8b_seq8k_gpu_fine_tuning",
+                recipe_overrides=recipe_overrides,
+                compute=compute,
+            )
+
+            model_trainer.train(wait=False)
+
+
+        Args:
+            training_recipe (str):
+                The training recipe to use for training the model. This must be the name of
+                a sagemaker training recipe or a path to a local training recipe .yaml file.
+                For available training recipes, see: https://github.com/aws/sagemaker-hyperpod-recipes/
+            compute (Compute):
+                The compute configuration. This is used to specify the compute resources for
+                the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
+            recipe_overrides (Optional[Dict[str, Any]]):
+                The recipe overrides. This is used to override the default recipe parameters.
+            networking (Optional[Networking]):
+                The networking configuration. This is used to specify the networking settings
+                for the training job.
+            stopping_condition (Optional[StoppingCondition]):
+                The stopping condition. This is used to specify the different stopping
+                conditions for the training job.
+                If not specified, will default to 1 hour max run time.
+            requirements (Optional[str]):
+                The path to a requirements file to install in the training job container.
+            training_image (Optional[str]):
+                The training image URI to use for the training job container. If not specified,
+                the training image will be determined from the recipe.
+            training_image_config (Optional[TrainingImageConfig]):
+                Training image Config. This is the configuration to use an image from a private
+                Docker registry for a training job.
+            output_data_config (Optional[OutputDataConfig]):
+                The output data configuration. This is used to specify the output data location
+                for the training job.
+                If not specified, will default to
+                ``s3://<default_bucket>/<base_job_name>/output/``.
+            input_data_config (Optional[List[Union[Channel, InputData]]]):
+                The input data config for the training job.
+                Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
+                string, local file path string, S3DataSource object, or FileSystemDataSource object.
+            checkpoint_config (Optional[CheckpointConfig]):
+                Contains information about the output location for managed spot training checkpoint
+                data.
+            training_input_mode (Optional[str]):
+                The input mode for the training job. Valid values are "Pipe", "File", "FastFile".
+                Defaults to "File".
+            environment (Optional[Dict[str, str]]):
+                The environment variables for the training job.
+            hyperparameters (Optional[Union[Dict[str, Any], str]]):
+                The hyperparameters for the training job. Can be a dictionary of hyperparameters
+                or a path to a hyperparameters json/yaml file. Only used for Nova recipes.
+            tags (Optional[List[Tag]]):
+                An array of key-value pairs. You can use tags to categorize your AWS resources
+                in different ways, for example, by purpose, owner, or environment.
+            sagemaker_session (Optional[Session]):
+                The SageMakerCore session.
+                If not specified, a new session will be created.
+            role (Optional[str]):
+                The IAM role ARN for the training job.
+                If not specified, the default SageMaker execution role will be used.
+            base_job_name (Optional[str]):
+                The base name for the training job.
+                If not specified, a default name will be generated using the algorithm name
+                or training image.
+        """
+        if compute.instance_type is None:
+            raise ValueError(
+                "Must set ``instance_type`` in ``compute`` input when using training recipes."
+            )
+        device_type = _determine_device_type(compute.instance_type)
+        recipe = _load_base_recipe(
+            training_recipe=training_recipe, recipe_overrides=recipe_overrides
+        )
+        is_nova = _is_nova_recipe(recipe=recipe)
+
+        if device_type == "cpu" and not is_nova:
+            raise ValueError(
+                "Training recipe is not supported for CPU instances. "
+                + "Please provide a GPU or Trainium instance type."
+            )
+        if training_image is None and is_nova:
+            raise ValueError("training_image must be provided when using recipe for Nova.")
+
+        if training_image_config and training_image is None:
+            raise ValueError("training_image must be provided when using training_image_config.")
+
+        if sagemaker_session is None:
+            sagemaker_session = Session()
+            logger.warning("SageMaker session not provided. Using default Session.")
+        if role is None:
+            role = get_execution_role(sagemaker_session=sagemaker_session)
+            logger.warning(f"Role not provided. Using default role:\n{role}")
+
+        # The training recipe is used to prepare the following args:
+        # - source_code
+        # - training_image
+        # - distributed
+        # - compute
+        # - hyperparameters
+        model_trainer_args, tmp_dir = _get_args_from_recipe(
+            training_recipe=recipe,
+            recipe_overrides=recipe_overrides,
+            requirements=requirements,
+            compute=compute,
+            region_name=sagemaker_session.boto_region_name,
+            role=role,
+        )
+        if training_image is not None:
+            model_trainer_args["training_image"] = training_image
+        if hyperparameters and not is_nova:
+            logger.warning(
+                "Hyperparameters are not supported for general training recipes. "
+                + "Ignoring hyperparameters input."
+ ) + if is_nova: + if hyperparameters and isinstance(hyperparameters, str): + hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters) + model_trainer_args["hyperparameters"].update(hyperparameters) + elif hyperparameters and isinstance(hyperparameters, dict): + model_trainer_args["hyperparameters"].update(hyperparameters) + + model_trainer = cls( + sagemaker_session=sagemaker_session, + role=role, + base_job_name=base_job_name, + networking=networking, + stopping_condition=stopping_condition, + training_image_config=training_image_config, + output_data_config=output_data_config, + input_data_config=input_data_config, + checkpoint_config=checkpoint_config, + training_input_mode=training_input_mode, + environment=environment, + tags=tags, + **model_trainer_args, + ) + model_trainer._is_nova_recipe = is_nova + model_trainer._temp_recipe_train_dir = tmp_dir + return model_trainer + + def with_tensorboard_output_config( + self, tensorboard_output_config: Optional[shapes.TensorBoardOutputConfig] = None + ) -> "ModelTrainer": # noqa: D412 + """Set the TensorBoard output configuration. + + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + + model_trainer = ModelTrainer( + ... + ).with_tensorboard_output_config() + + Args: + tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig): + The TensorBoard output configuration. + """ + self._tensorboard_output_config = ( + tensorboard_output_config or configs.TensorBoardOutputConfig() + ) + return self + + def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": # noqa: D412 + """Set the retry strategy for the training job. + + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import RetryStrategy + + retry_strategy = RetryStrategy(maximum_retry_attempts=3) + + model_trainer = ModelTrainer( + ... + ).with_retry_strategy(retry_strategy) + + Args: + retry_strategy (sagemaker.modules.configs.RetryStrategy): + The retry strategy for the training job. + """ + self._retry_strategy = retry_strategy + return self + + def with_infra_check_config( + self, infra_check_config: Optional[InfraCheckConfig] = None + ) -> "ModelTrainer": # noqa: D412 + """Set the infra check configuration for the training job. + + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + + model_trainer = ModelTrainer( + ... + ).with_infra_check_config() + + Args: + infra_check_config (sagemaker.modules.configs.InfraCheckConfig): + The infra check configuration for the training job. + """ + self._infra_check_config = infra_check_config or InfraCheckConfig(enable_infra_check=True) + return self + + def with_session_chaining_config( + self, session_chaining_config: Optional[SessionChainingConfig] = None + ) -> "ModelTrainer": # noqa: D412 + """Set the session chaining configuration for the training job. + + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + + model_trainer = ModelTrainer( + ... + ).with_session_chaining_config() + + Args: + session_chaining_config (sagemaker.modules.configs.SessionChainingConfig): + The session chaining configuration for the training job. 
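+                If not provided, defaults to
+                ``SessionChainingConfig(enable_session_tag_chaining=True)``.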
+ """ + self._session_chaining_config = session_chaining_config or SessionChainingConfig( + enable_session_tag_chaining=True + ) + return self + + def with_remote_debug_config( + self, remote_debug_config: Optional[RemoteDebugConfig] = None + ) -> "ModelTrainer": # noqa: D412 + """Set the remote debug configuration for the training job. + + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + + model_trainer = ModelTrainer( + ... + ).with_remote_debug_config() + + Args: + remote_debug_config (sagemaker.modules.configs.RemoteDebugConfig): + The remote debug configuration for the training job. + """ + self._remote_debug_config = remote_debug_config or RemoteDebugConfig( + enable_remote_debug=True + ) + return self + + def with_checkpoint_config( + self, checkpoint_config: Optional[shapes.CheckpointConfig] = None + ) -> "ModelTrainer": # noqa: D412 + """Set the checkpoint configuration for the training job. + + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + + model_trainer = ModelTrainer( + ... + ).with_checkpoint_config() + + Args: + checkpoint_config (sagemaker.modules.configs.CheckpointConfig): + The checkpoint configuration for the training job. + """ + self.checkpoint_config = checkpoint_config or configs.CheckpointConfig() + return self + + def with_metric_definitions( + self, metric_definitions: List[MetricDefinition] + ) -> "ModelTrainer": # noqa: D412 + """Set the metric definitions for the training job. + + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import MetricDefinition + + metric_definitions = [ + MetricDefinition( + name="loss", + regex="Loss: (.*?)", + ) + ] + + model_trainer = ModelTrainer( + ... + ).with_metric_definitions(metric_definitions) + + Args: + metric_definitions (List[MetricDefinition]): + The metric definitions for the training job. + """ + self._metric_definitions = metric_definitions + return self diff --git a/src/sagemaker/modules/train/sm_recipes/__init__.py b/src/sagemaker/modules/train/sm_recipes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/modules/train/sm_recipes/training_recipes.json b/src/sagemaker/modules/train/sm_recipes/training_recipes.json new file mode 100644 index 0000000000..a51513f49f --- /dev/null +++ b/src/sagemaker/modules/train/sm_recipes/training_recipes.json @@ -0,0 +1,17 @@ +{ + "adapter_repo": "https://github.com/aws/sagemaker-training-adapter-for-nemo.git", + "launcher_repo": "https://github.com/aws/sagemaker-hyperpod-recipes.git", + "neuron_dist_repo": "https://github.com/aws-neuron/neuronx-distributed-training.git", + "gpu_image" : { + "framework": "pytorch-smp", + "version": "2.4.1", + "additional_args": { + "container_version": "cu121" + } + }, + "neuron_image": { + "framework": "hyperpod-recipes-neuron", + "version": "2.1.2", + "additional_args": {} + } +} \ No newline at end of file diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py new file mode 100644 index 0000000000..b6523e14dd --- /dev/null +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -0,0 +1,438 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Utility functions for SageMaker training recipes."""
+from __future__ import absolute_import
+
+import math
+import os
+import json
+import shutil
+import tempfile
+from urllib.request import urlretrieve
+from typing import Dict, Any, Optional, Tuple, Union
+
+import omegaconf
+from omegaconf import OmegaConf, dictconfig, DictConfig
+
+from sagemaker.image_uris import retrieve
+
+from sagemaker.modules import logger
+from sagemaker.modules.utils import _run_clone_command_silent
+from sagemaker.modules.constants import SM_RECIPE_YAML
+from sagemaker.modules.configs import Compute, SourceCode
+from sagemaker.modules.distributed import Torchrun, SMP
+
+
+def _try_resolve_recipe(recipe: DictConfig, key=None) -> Optional[DictConfig]:
+    """Try to resolve the recipe; return the resolved recipe, or None on failure."""
+    if key is not None:
+        recipe = dictconfig.DictConfig({key: recipe})
+    try:
+        OmegaConf.resolve(recipe)
+    except omegaconf.errors.OmegaConfBaseException:
+        return None
+    if key is None:
+        return recipe
+    return recipe[key]
+
+
+def _determine_device_type(instance_type: str) -> str:
+    """Determine device type (gpu, cpu, trainium) based on instance type."""
+    instance_family = instance_type.split(".")[1]
+    if instance_family.startswith(("p", "g")):
+        return "gpu"
+    if instance_family.startswith("trn"):
+        return "trainium"
+    return "cpu"
+
+
+def _load_recipes_cfg() -> Dict[str, Any]:
+    """Load training recipes configuration json."""
+    training_recipes_cfg_filename = os.path.join(os.path.dirname(__file__), "training_recipes.json")
+    with open(training_recipes_cfg_filename) as training_recipes_cfg_file:
+        training_recipes_cfg = json.load(training_recipes_cfg_file)
+    return training_recipes_cfg
+
+
+def _load_base_recipe(
+    training_recipe: str,
+    recipe_overrides: Optional[Dict[str, Any]] = None,
+    training_recipes_cfg: Optional[Dict[str, Any]] = None,
+) -> DictConfig:
+    """Load recipe and apply overrides."""
+    if recipe_overrides is None:
+        recipe_overrides = dict()
+
+    temp_local_recipe = tempfile.NamedTemporaryFile(prefix="recipe_original", suffix=".yaml").name
+
+    if training_recipe.endswith(".yaml"):
+        if os.path.isfile(training_recipe):
+            shutil.copy(training_recipe, temp_local_recipe)
+        else:
+            try:
+                urlretrieve(training_recipe, temp_local_recipe)
+            except Exception as e:
+                raise ValueError(
+                    f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}"
+                )
+    else:
+        recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
+        if training_recipes_cfg is None:
+            training_recipes_cfg = _load_recipes_cfg()
+
+        launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get(
+            "launcher_repo"
+        )
+        _run_clone_command_silent(launcher_repo, recipe_launcher_dir.name)
+
+        recipe = os.path.join(
+            recipe_launcher_dir.name,
+            "recipes_collection",
+            "recipes",
+            training_recipe + ".yaml",
+        )
+        if os.path.isfile(recipe):
+            shutil.copy(recipe, temp_local_recipe)
+        else:
+            raise ValueError(f"Recipe {training_recipe} not found.")
+
+    recipe = OmegaConf.load(temp_local_recipe)
+    os.unlink(temp_local_recipe)
+    recipe = OmegaConf.merge(recipe, recipe_overrides)
+    return recipe
+
+
+def _register_custom_resolvers():
"""Register custom resolvers for OmegaConf.""" + if not OmegaConf.has_resolver("multiply"): + OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) + if not OmegaConf.has_resolver("divide_ceil"): + OmegaConf.register_new_resolver( + "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True + ) + if not OmegaConf.has_resolver("divide_floor"): + OmegaConf.register_new_resolver( + "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True + ) + if not OmegaConf.has_resolver("add"): + OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers)) + + +def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): + """Get the model base name and script for the training recipe.""" + + model_type_to_script = { + "llama": ("llama", "llama_pretrain.py"), + "mistral": ("mistral", "mistral_pretrain.py"), + "mixtral": ("mixtral", "mixtral_pretrain.py"), + "deepseek": ("deepseek", "deepseek_pretrain.py"), + "gpt_oss": ("custom_model", "custom_pretrain.py"), + } + + for key in model_type_to_script: + if model_type.startswith(key): + model_type = key + break + + if model_type not in model_type_to_script: + raise ValueError(f"Model type {model_type} not supported") + + return model_type_to_script[model_type][0], model_type_to_script[model_type][1] + + +def _configure_gpu_args( + training_recipes_cfg: Dict[str, Any], + region_name: str, + recipe: DictConfig, + recipe_train_dir: tempfile.TemporaryDirectory, +) -> Dict[str, Any]: + """Configure arguments specific to GPU.""" + source_code = SourceCode() + args = dict() + + adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get( + "adapter_repo" + ) + _run_clone_command_silent(adapter_repo, recipe_train_dir.name) + + if "model" not in recipe: + raise ValueError("Supplied recipe does not contain required field model.") + if "model_type" not in recipe["model"]: + raise ValueError("Supplied recipe does not contain required field model_type.") + model_type = recipe["model"]["model_type"] + + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) + + source_code.source_dir = os.path.join(recipe_train_dir.name, "examples", model_base_name) + source_code.entry_script = script + + gpu_image_cfg = training_recipes_cfg.get("gpu_image") + if isinstance(gpu_image_cfg, str): + training_image = gpu_image_cfg + else: + training_image = retrieve( + gpu_image_cfg.get("framework"), + region=region_name, + version=gpu_image_cfg.get("version"), + image_scope="training", + **gpu_image_cfg.get("additional_args"), + ) + + # Setting dummy parameters for now + torch_distributed = Torchrun(smp=SMP(random_seed="123456")) + args.update( + { + "source_code": source_code, + "training_image": training_image, + "distributed": torch_distributed, + } + ) + return args + + +def _configure_trainium_args( + training_recipes_cfg: Dict[str, Any], + region_name: str, + recipe_train_dir: tempfile.TemporaryDirectory, +) -> Dict[str, Any]: + """Configure arguments specific to Trainium.""" + source_code = SourceCode() + args = dict() + + _run_clone_command_silent(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir.name) + + source_code.source_dir = os.path.join(recipe_train_dir.name, "examples") + source_code.entry_script = "training_orchestrator.py" + neuron_image_cfg = training_recipes_cfg.get("neuron_image") + if isinstance(neuron_image_cfg, str): + training_image = neuron_image_cfg + else: + training_image = retrieve( + neuron_image_cfg.get("framework"), + 
region=region_name, + version=neuron_image_cfg.get("version"), + image_scope="training", + **neuron_image_cfg.get("additional_args"), + ) + + args.update( + { + "source_code": source_code, + "training_image": training_image, + "distributed": Torchrun(), + } + ) + return args + + +def _is_nova_recipe( + recipe: DictConfig, +) -> bool: + """Check if the recipe is a Nova recipe. + + A recipe is considered a Nova recipe if it meets either of the following conditions: + + 1. It has a run section with: + - A model_type that includes "amazon.nova" + - A model_name_or_path field + + OR + + 2. It has a training_config section with: + - A distillation_data field + + Args: + recipe (DictConfig): The loaded recipe configuration + + Returns: + bool: True if the recipe is a Nova recipe, False otherwise + """ + run_config = recipe.get("run", {}) + model_type = run_config.get("model_type", "").lower() + has_nova_model = ( + model_type and "amazon.nova" in model_type and "model_name_or_path" in run_config + ) + + # Check for distillation data + training_config = recipe.get("training_config", {}) + has_distillation = training_config.get("distillation_data") is not None + return bool(has_nova_model) or bool(has_distillation) + + +def _get_args_from_nova_recipe( + recipe: DictConfig, + compute: Compute, + role: Optional[str] = None, +) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: + """Get ModelTrainer arguments and a local recipe directory from a Nova recipe.""" + if not compute.instance_count and not recipe.get("run", {}).get("replicas", None): + raise ValueError("Must set ``instance_count`` in compute or ``replicas`` in recipe.") + compute.instance_count = compute.instance_count or recipe.get("run", {}).get("replicas") + + args = dict() + args.update({"hyperparameters": {}}) + + run_config = recipe.get("run", {}) + model_name_or_path = run_config.get("model_name_or_path") + if model_name_or_path: + if model_name_or_path.startswith("s3://"): + args["hyperparameters"]["base_model_location"] = model_name_or_path + else: + args["hyperparameters"]["base_model"] = model_name_or_path + + # Handle distillation configuration + training_config = recipe.get("training_config", {}) + distillation_data = training_config.get("distillation_data") + if bool(distillation_data): + args["hyperparameters"]["distillation_data"] = distillation_data + if not role: + raise ValueError("Must provide 'role' parameter when using Nova distillation") + args["hyperparameters"]["role_arn"] = role + + kms_key = training_config.get("kms_key") + if kms_key is None: + raise ValueError( + 'Nova distillation job recipe requires "kms_key" field in "training_config"' + ) + args["hyperparameters"]["kms_key"] = kms_key + + _register_custom_resolvers() + + # Resolve Final Recipe + final_recipe = _try_resolve_recipe(recipe) + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "recipes") + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "training") + if final_recipe is None: + raise RuntimeError("Could not resolve provided recipe.") + + # Save Final Recipe to tmp dir + recipe_local_dir = tempfile.TemporaryDirectory(prefix="recipe_") + final_recipe_path = os.path.join(recipe_local_dir.name, SM_RECIPE_YAML) + OmegaConf.save(config=final_recipe, f=final_recipe_path) + + args.update( + { + "compute": compute, + "training_image": None, + "source_code": None, + "distributed": None, + } + ) + return args, recipe_local_dir + + +def _get_args_from_recipe( + training_recipe: Union[str, DictConfig], + compute: Compute, + region_name: str, + recipe_overrides: Optional[Dict[str, Any]],
requirements: Optional[str], + role: Optional[str] = None, +) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]: + """Get arguments for ModelTrainer from a training recipe. + + Returns a dictionary of arguments to be used with ModelTrainer like: + ```python + { + "source_code": SourceCode, + "training_image": str, + "distributed": DistributedConfig, + "compute": Compute, + "hyperparameters": Dict[str, Any], + } + ``` + + Args: + training_recipe (Union[str, DictConfig]): + Name of the training recipe, a path to the recipe file, or a loaded recipe. + compute (Compute): + Compute configuration for training. + region_name (str): + Name of the AWS region. + recipe_overrides (Optional[Dict[str, Any]]): + Overrides for the training recipe. + requirements (Optional[str]): + Path to the requirements file. + role (Optional[str]): + IAM role ARN; required when the recipe is a Nova distillation recipe. + """ + if compute.instance_type is None: + raise ValueError("Must set `instance_type` in compute when using training recipes.") + + training_recipes_cfg = _load_recipes_cfg() + if isinstance(training_recipe, str): + recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg) + else: + recipe = training_recipe + if _is_nova_recipe(recipe): + args, recipe_local_dir = _get_args_from_nova_recipe(recipe, compute, role=role) + return args, recipe_local_dir + + if "trainer" not in recipe: + raise ValueError("Supplied recipe does not contain required field trainer.") + + # Set instance_count + if compute.instance_count and "num_nodes" in recipe["trainer"]: + logger.warning( + f"Using Compute to set instance_count:\n{compute}." + "\nIgnoring trainer -> num_nodes in recipe." + ) + if compute.instance_count is None: + if "num_nodes" not in recipe["trainer"]: + raise ValueError( + "Must provide Compute with instance_count or set trainer -> num_nodes in recipe."
+ ) + compute.instance_count = recipe["trainer"]["num_nodes"] + + if requirements and not os.path.isfile(requirements): + raise ValueError(f"Recipe requirements file {requirements} not found.") + + # Get Training Image, SourceCode, and distributed args + device_type = _determine_device_type(compute.instance_type) + recipe_train_dir = tempfile.TemporaryDirectory(prefix="training_") + if device_type == "gpu": + args = _configure_gpu_args(training_recipes_cfg, region_name, recipe, recipe_train_dir) + elif device_type == "trainium": + args = _configure_trainium_args(training_recipes_cfg, region_name, recipe_train_dir) + else: + raise ValueError(f"Devices of type {device_type} are not supported with training recipes.") + + _register_custom_resolvers() + + # Resolve Final Recipe + final_recipe = _try_resolve_recipe(recipe) + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "recipes") + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "training") + if final_recipe is None: + raise RuntimeError("Could not resolve provided recipe.") + + # Save Final Recipe to source_dir + OmegaConf.save( + config=final_recipe, f=os.path.join(args["source_code"].source_dir, SM_RECIPE_YAML) + ) + + # If recipe_requirements is provided, copy it to source_dir + if requirements: + shutil.copy(requirements, args["source_code"].source_dir) + args["source_code"].requirements = os.path.basename(requirements) + + # Update args with compute and hyperparameters + args.update( + { + "compute": compute, + "hyperparameters": {"config-path": ".", "config-name": SM_RECIPE_YAML}, + } + ) + + return args, recipe_train_dir diff --git a/src/sagemaker/modules/types.py b/src/sagemaker/modules/types.py new file mode 100644 index 0000000000..18bdcce3bd --- /dev/null +++ b/src/sagemaker/modules/types.py @@ -0,0 +1,19 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Types module.""" +from __future__ import absolute_import + +from typing import Union +from sagemaker.modules.configs import S3DataSource, FileSystemDataSource + +DataSourceType = Union[str, S3DataSource, FileSystemDataSource] diff --git a/src/sagemaker/modules/utils.py b/src/sagemaker/modules/utils.py new file mode 100644 index 0000000000..502f1bbc74 --- /dev/null +++ b/src/sagemaker/modules/utils.py @@ -0,0 +1,194 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
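For orientation, the helpers above culminate in `_get_args_from_recipe`, which the `ModelTrainer` entry points in `sagemaker.modules.train` consume. A minimal usage sketch follows; it assumes `ModelTrainer.from_recipe` is the public wrapper over this plumbing, and the recipe name, compute settings, and override keys are illustrative placeholders only:

```python
# Hedged sketch: launching a training recipe through ModelTrainer.
# Assumptions: ModelTrainer.from_recipe wraps _get_args_from_recipe; the
# recipe name and compute settings below are placeholders.
from sagemaker.modules.configs import Compute
from sagemaker.modules.train import ModelTrainer

trainer = ModelTrainer.from_recipe(
    training_recipe="training/llama/hf_llama3_8b_seq8k_gpu_p5x16_pretrain",  # illustrative
    compute=Compute(instance_type="ml.p5.48xlarge", instance_count=16),
    recipe_overrides={"trainer": {"max_steps": 50}},  # illustrative override
)
trainer.train()  # resolves and stages the recipe, then starts the training job
```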
+"""Utils module.""" +from __future__ import absolute_import + +import os +import json +import subprocess +import tempfile +from pathlib import Path + +from datetime import datetime +from typing import Literal, Any + +from sagemaker_core.shapes import Unassigned +from sagemaker.modules import logger + + +def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: + """Check if the path is a valid S3 URI. + + This method checks if the path is a valid S3 URI. If the path_type is specified, + it will also check if the path is a file or a directory. + This method does not check if the S3 bucket or object exists. + + Args: + path (str): S3 URI to validate + path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate. + Defaults to "Any". + + Returns: + bool: True if the path is a valid S3 URI, False otherwise + """ + # Check if the path is a valid S3 URI + if not path.startswith("s3://"): + return False + + if path_type == "File": + # If it's a file, it should not end with a slash + return not path.endswith("/") + if path_type == "Directory": + # If it's a directory, it should end with a slash + return path.endswith("/") + + return path_type == "Any" + + +def _is_valid_path(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool: + """Check if the path is a valid local path. + + Args: + path (str): Local path to validate + path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate. + Defaults to "Any". + + Returns: + bool: True if the path is a valid local path, False otherwise + """ + if not os.path.exists(path): + return False + + if path_type == "File": + return os.path.isfile(path) + if path_type == "Directory": + return os.path.isdir(path) + + return path_type == "Any" + + +def _get_unique_name(base, max_length=63): + """Generate a unique name based on the base name. + + This method generates a unique name based on the base name. + The unique name is generated by appending the current timestamp + to the base name. + + Args: + base (str): The base name to use + max_length (int): The maximum length of the unique name. Defaults to 63. + + Returns: + str: The unique name + """ + current_time = datetime.now().strftime("%Y%m%d%H%M%S") + base = base.replace("_", "-") + unique_name = f"{base}-{current_time}" + unique_name = unique_name[:max_length] # Truncate to max_length + return unique_name + + +def _get_repo_name_from_image(image: str) -> str: + """Get the repository name from the image URI. + + Example: + ``` python + _get_repo_name_from_image("123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest") + # Returns "my-repo" + ``` + + Args: + image (str): The image URI + + Returns: + str: The repository name + """ + return image.split("/")[-1].split(":")[0] + + +def convert_unassigned_to_none(instance) -> Any: + """Convert Unassigned values to None for any instance.""" + for name, value in instance.__dict__.items(): + if isinstance(value, Unassigned): + setattr(instance, name, None) + return instance + + +def safe_serialize(data): + """Serialize the data without wrapping strings in quotes. + + This function handles the following cases: + 1. If `data` is a string, it returns the string as-is without wrapping in quotes. + 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns + the JSON-encoded string using `json.dumps()`. + 3. 
If `data` cannot be serialized (e.g., a custom object), it returns the string + representation of the data using `str(data)`. + + Args: + data (Any): The data to serialize. + + Returns: + str: The serialized JSON-compatible string or the string representation of the input. + """ + if isinstance(data, str): + return data + try: + return json.dumps(data) + except TypeError: + return str(data) + + +def _run_clone_command_silent(repo_url, dest_dir): + """Run the 'git clone' command with the repo url and the directory to clone the repo into. + + Args: + repo_url (str): Git repo url to be cloned. + dest_dir: (str): Local path where the repo should be cloned into. + + Raises: + CalledProcessError: If failed to clone git repo. + """ + my_env = os.environ.copy() + if repo_url.startswith("https://"): + try: + my_env["GIT_TERMINAL_PROMPT"] = "0" + subprocess.check_call( + ["git", "clone", repo_url, dest_dir], + env=my_env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except subprocess.CalledProcessError as e: + logger.error(f"Failed to clone repository: {repo_url}") + logger.error(f"Error output:\n{e}") + raise + elif repo_url.startswith("git@") or repo_url.startswith("ssh://"): + try: + with tempfile.TemporaryDirectory() as tmp_dir: + custom_ssh_executable = Path(tmp_dir) / "ssh_batch" + with open(custom_ssh_executable, "w") as pipe: + print("#!/bin/sh", file=pipe) + print("ssh -oBatchMode=yes $@", file=pipe) + os.chmod(custom_ssh_executable, 0o511) + my_env["GIT_SSH"] = str(custom_ssh_executable) + subprocess.check_call( + ["git", "clone", repo_url, dest_dir], + env=my_env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except subprocess.CalledProcessError as e: + del my_env["GIT_SSH"] + logger.error(f"Failed to clone repository: {repo_url}") + logger.error(f"Error output:\n{e}") + raise diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index 9c1e6ac4f4..43a3588e6f 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -126,6 +126,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition set. @@ -154,6 +155,7 @@ def prepare_container_def( model_data_url=self.model_data_prefix, container_mode=self.container_mode, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def deploy( @@ -221,7 +223,7 @@ def deploy( Amazon SageMaker Model Monitoring. Default: None. Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 7d0ce2d494..5126a37a85 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -50,7 +50,7 @@ def __init__( hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, distribution: Optional[Dict[str, str]] = None, - **kwargs + **kwargs, ): """This ``Estimator`` executes an MXNet script in a managed MXNet execution environment. @@ -84,8 +84,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). 
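To make the ``entry_point``/``source_dir`` relationship concrete, here is a minimal, hypothetical construction (role ARN, paths, and framework/Python versions are placeholders, not values confirmed by this diff):

```python
# Minimal sketch: entry_point is resolved relative to source_dir, and the
# whole source_dir tree is shipped to the training container.
from sagemaker.mxnet import MXNet

estimator = MXNet(
    entry_point="train.py",       # lives inside source_dir
    source_dir="src/",            # local dir, or s3://.../sourcedir.tar.gz
    role="arn:aws:iam::123456789012:role/ExampleSageMakerRole",  # hypothetical
    framework_version="1.9.0",    # illustrative versions
    py_version="py38",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)
```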
If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + is preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on @@ -222,7 +222,7 @@ def create_model( source_dir=None, dependencies=None, image_uri=None, - **kwargs + **kwargs, ): """Create a SageMaker ``MXNetModel`` object that can be deployed to an ``Endpoint``. @@ -283,7 +283,7 @@ def create_model( sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), dependencies=(dependencies or self.dependencies), - **kwargs + **kwargs, ) if entry_point is None: diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 714b0db945..fa0c691d2d 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, List, Dict +from typing import Callable, Union, Optional, List, Dict import packaging.version @@ -29,12 +29,17 @@ ) from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.mxnet import defaults from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.utils import to_string from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger("sagemaker") @@ -63,9 +68,9 @@ def __init__( manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. - serializer (callable): Optional. Default serializes input data to + serializer (Callable): Optional. Default serializes input data to json. Handles dicts, lists, and numpy arrays. - deserializer (callable): Optional. Default parses the response using + deserializer (Callable): Optional. Default parses the response using ``json.load(...)``. component_name (str): Optional. Name of the Amazon SageMaker inference component corresponding to the predictor. @@ -93,9 +98,9 @@ def __init__( framework_version: str = _LOWEST_MMS_VERSION, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = MXNetPredictor, + predictor_cls: Optional[Callable] = MXNetPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, - **kwargs + **kwargs, ): """Initialize an MXNetModel. @@ -122,7 +127,7 @@ def __init__( If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If ``image_uri`` is also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name.
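The ``predictor_cls`` contract documented above can be read as: when the class is set, ``deploy()`` ends with ``return predictor_cls(endpoint_name, sagemaker_session)``. A hedged sketch (artifact path, role, and versions are hypothetical):

```python
# Sketch of the predictor_cls contract; all resource names are placeholders.
from sagemaker.mxnet import MXNetModel, MXNetPredictor

model = MXNetModel(
    model_data="s3://example-bucket/model.tar.gz",  # hypothetical artifact
    role="arn:aws:iam::123456789012:role/ExampleSageMakerRole",  # hypothetical
    entry_point="inference.py",
    framework_version="1.9.0",
    py_version="py38",
    predictor_cls=MXNetPredictor,  # the default; shown explicitly for clarity
)
# deploy() returns MXNetPredictor(endpoint_name, sagemaker_session)
predictor = model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")
```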
@@ -177,6 +182,8 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -228,6 +235,9 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModelCard or ModelPackageModelCard): A document containing qualitative and + quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -268,6 +278,8 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + model_life_cycle=model_life_cycle, ) def prepare_container_def( @@ -276,6 +288,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition with framework configuration. @@ -329,6 +342,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/partner_app/__init__.py b/src/sagemaker/partner_app/__init__.py new file mode 100644 index 0000000000..b9ef202bc7 --- /dev/null +++ b/src/sagemaker/partner_app/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""__init__ file for sagemaker.partner_app.auth_provider""" +from __future__ import absolute_import + +from sagemaker.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401 diff --git a/src/sagemaker/partner_app/auth_provider.py b/src/sagemaker/partner_app/auth_provider.py new file mode 100644 index 0000000000..2e0d7da94c --- /dev/null +++ b/src/sagemaker/partner_app/auth_provider.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
+ +"""The SageMaker partner application SDK auth module""" +from __future__ import absolute_import + +import os +import re +from typing import Dict, Tuple + +import boto3 +from botocore.auth import SigV4Auth +from botocore.credentials import Credentials +from requests.auth import AuthBase +from requests.models import PreparedRequest +from sagemaker.partner_app.auth_utils import PartnerAppAuthUtils + +SERVICE_NAME = "sagemaker" +AWS_PARTNER_APP_ARN_REGEX = r"arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:partner-app\/.*" + + +class RequestsAuth(AuthBase): + """Requests authentication class for SigV4 header generation. + + This class is used to generate the SigV4 header and add it to the request headers. + """ + + def __init__(self, sigv4: SigV4Auth, app_arn: str): + """Initialize the RequestsAuth class. + + Args: + sigv4 (SigV4Auth): SigV4Auth object + app_arn (str): Application ARN + """ + self.sigv4 = sigv4 + self.app_arn = app_arn + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + """Callback function to generate the SigV4 header and add it to the request headers. + + Args: + request (PreparedRequest): PreparedRequest object + + Returns: + PreparedRequest: PreparedRequest object with the SigV4 header added + """ + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + sigv4=self.sigv4, + app_arn=self.app_arn, + url=request.url, + method=request.method, + headers=request.headers, + body=request.body, + ) + request.url = url + request.headers.update(signed_headers) + + return request + + +class PartnerAppAuthProvider: + """The SageMaker partner application SDK auth provider class""" + + def __init__(self, credentials: Credentials = None): + """Initialize the PartnerAppAuthProvider class. + + Args: + credentials (Credentials, optional): AWS credentials. Defaults to None. + Raises: + ValueError: If the AWS_PARTNER_APP_ARN environment variable is not set or is invalid. + """ + self.app_arn = os.getenv("AWS_PARTNER_APP_ARN") + if self.app_arn is None: + raise ValueError("Must specify the AWS_PARTNER_APP_ARN environment variable") + + app_arn_regex_match = re.search(AWS_PARTNER_APP_ARN_REGEX, self.app_arn) + if app_arn_regex_match is None: + raise ValueError("Must specify a valid AWS_PARTNER_APP_ARN environment variable") + + split_arn = self.app_arn.split(":") + self.region = split_arn[3] + + self.credentials = ( + credentials if credentials is not None else boto3.Session().get_credentials() + ) + self.sigv4 = SigV4Auth(self.credentials, SERVICE_NAME, self.region) + + def get_signed_request( + self, url: str, method: str, headers: dict, body: object + ) -> Tuple[str, Dict[str, str]]: + """Generate the SigV4 header and add it to the request headers. + + Args: + url (str): Request URL + method (str): HTTP method + headers (dict): Request headers + body (object): Request body + + Returns: + tuple: (url, headers) + """ + return PartnerAppAuthUtils.get_signed_request( + sigv4=self.sigv4, + app_arn=self.app_arn, + url=url, + method=method, + headers=headers, + body=body, + ) + + def get_auth(self) -> RequestsAuth: + """Returns the callback class (RequestsAuth) used for generating the SigV4 header. + + Returns: + RequestsAuth: Callback Object which will calculate the header just before + request submission. 
+ """ + + return RequestsAuth(self.sigv4, os.environ["AWS_PARTNER_APP_ARN"]) diff --git a/src/sagemaker/partner_app/auth_utils.py b/src/sagemaker/partner_app/auth_utils.py new file mode 100644 index 0000000000..eb1dcacaa9 --- /dev/null +++ b/src/sagemaker/partner_app/auth_utils.py @@ -0,0 +1,122 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +"""Partner App Auth Utils Module""" + +from __future__ import absolute_import + +from hashlib import sha256 +import functools +from typing import Tuple, Dict + +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest + +HEADER_CONNECTION = "Connection" +HEADER_X_AMZ_TARGET = "X-Amz-Target" +HEADER_AUTHORIZATION = "Authorization" +HEADER_PARTNER_APP_SERVER_ARN = "X-SageMaker-Partner-App-Server-Arn" +HEADER_PARTNER_APP_AUTHORIZATION = "X-Amz-Partner-App-Authorization" +HEADER_X_AMZ_CONTENT_SHA_256 = "X-Amz-Content-SHA256" +CALL_PARTNER_APP_API_ACTION = "SageMaker.CallPartnerAppApi" + +PAYLOAD_BUFFER = 1024 * 1024 +EMPTY_SHA256_HASH = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" +UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD" + + +class PartnerAppAuthUtils: + """Partner App Auth Utils Class""" + + @staticmethod + def get_signed_request( + sigv4: SigV4Auth, app_arn: str, url: str, method: str, headers: dict, body: object + ) -> Tuple[str, Dict[str, str]]: + """Generate the SigV4 header and add it to the request headers. + + Args: + sigv4 (SigV4Auth): SigV4Auth object + app_arn (str): Application ARN + url (str): Request URL + method (str): HTTP method + headers (dict): Request headers + body (object): Request body + Returns: + tuple: (url, headers) + """ + # Move API key to X-Amz-Partner-App-Authorization + if HEADER_AUTHORIZATION in headers: + headers[HEADER_PARTNER_APP_AUTHORIZATION] = headers[HEADER_AUTHORIZATION] + + # App Arn + headers[HEADER_PARTNER_APP_SERVER_ARN] = app_arn + + # IAM Action + headers[HEADER_X_AMZ_TARGET] = CALL_PARTNER_APP_API_ACTION + + # Body + headers[HEADER_X_AMZ_CONTENT_SHA_256] = PartnerAppAuthUtils.get_body_header(body) + + # Connection header is excluded from server-side signature calculation + connection_header = headers[HEADER_CONNECTION] if HEADER_CONNECTION in headers else None + + if HEADER_CONNECTION in headers: + del headers[HEADER_CONNECTION] + + # Spaces are encoded as %20 + url = url.replace("+", "%20") + + # Calculate SigV4 header + aws_request = AWSRequest( + method=method, + url=url, + headers=headers, + data=body, + ) + sigv4.add_auth(aws_request) + + # Reassemble headers + final_headers = dict(aws_request.headers.items()) + if connection_header is not None: + final_headers[HEADER_CONNECTION] = connection_header + + return (url, final_headers) + + @staticmethod + def get_body_header(body: object): + """Calculate the body header for the SigV4 header. 
+ + Args: + body (object): Request body + """ + if body and hasattr(body, "seek"): + position = body.tell() + read_chunksize = functools.partial(body.read, PAYLOAD_BUFFER) + checksum = sha256() + for chunk in iter(read_chunksize, b""): + checksum.update(chunk) + hex_checksum = checksum.hexdigest() + body.seek(position) + return hex_checksum + + if body and not isinstance(body, bytes): + # Body is of a class we don't recognize, so don't sign the payload + return UNSIGNED_PAYLOAD + + if body: + # The request serialization has ensured that + # request.body is a bytes() type. + return sha256(body).hexdigest() + + # Body is None + return EMPTY_SHA256_HASH diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 06d2ecfcde..403445525b 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -32,6 +32,7 @@ def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, @@ -78,11 +79,12 @@ def retrieve_all_examples( unserialized_payload_dict: Optional[Dict[str, JumpStartSerializablePayload]] = ( artifacts._retrieve_example_payloads( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + region=region, + hub_arn=hub_arn, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) @@ -123,6 +125,7 @@ def retrieve_example( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, @@ -168,6 +171,7 @@ def retrieve_example( region=region, model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, serialize=serialize, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 3bfdb1a594..b36cd4e917 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -13,10 +13,12 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Optional, Dict, List, Union +from typing import Callable, Optional, Dict, List, Union import sagemaker from sagemaker import ModelMetrics, Model +from sagemaker import local +from sagemaker import session from sagemaker.config import ( ENDPOINT_CONFIG_KMS_KEY_ID_PATH, MODEL_VPC_CONFIG_PATH, @@ -26,6 +28,11 @@ ) from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties +from sagemaker.model import ModelPackage +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.session import Session from sagemaker.utils import ( name_from_image, @@ -49,7 +56,7 @@ def __init__( self, models: List[Model], role: str = None, - predictor_cls: Optional[callable] = None, + predictor_cls: Optional[Callable] = None, name: Optional[str] = None, vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, sagemaker_session: Optional[Session] = None, @@ -70,7 +77,7 @@ def __init__( endpoints use this role to access training data and model artifacts. 
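For context on the class whose ``__init__`` is being edited here: a ``PipelineModel`` chains several containers behind a single endpoint. A hedged sketch with placeholder images, artifacts, and role:

```python
# Illustrative PipelineModel: two containers invoked in order behind one
# endpoint. All URIs, artifacts, and the role are hypothetical placeholders.
from sagemaker import Model
from sagemaker.pipeline import PipelineModel

preprocess_model = Model(
    image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/preprocess:latest",  # hypothetical
    model_data="s3://example-bucket/preprocess.tar.gz",
)
inference_model = Model(
    image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/inference:latest",  # hypothetical
    model_data="s3://example-bucket/model.tar.gz",
)
pipeline_model = PipelineModel(
    models=[preprocess_model, inference_model],
    role="arn:aws:iam::123456789012:role/ExampleSageMakerRole",  # hypothetical
)
predictor = pipeline_model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")
```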
After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. - predictor_cls (callable[string, sagemaker.session.Session]): A + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor (default: None). If not None, ``deploy`` will return the result of invoking this function on the created endpoint name. @@ -225,7 +232,7 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests Returns: - callable[string, sagemaker.session.Session] or None: Invocation of + Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ @@ -361,6 +368,7 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -412,6 +420,8 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModelCard or ModelPackageModelCard): A document containing qualitative and + quantitative information about a model (default: None). Returns: If ``sagemaker_session`` is a ``PipelineSession`` instance, returns pipeline step @@ -460,9 +470,21 @@ def register( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + ) + + model_package = self.sagemaker_session.create_model_package_from_containers( + **model_pkg_args ) - self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) + if model_package is not None and "ModelPackageArn" in model_package: + return ModelPackage( + role=self.role, + model_package_arn=model_package.get("ModelPackageArn"), + sagemaker_session=self.sagemaker_session, + predictor_cls=self.predictor_cls, + ) + return None def transformer( self, @@ -540,3 +562,16 @@ def delete_model(self): raise ValueError("The SageMaker model must be created before attempting to delete.") self.sagemaker_session.delete_model(self.name) + + def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): + """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already. + + The type of session object is determined by the instance type.
+ """ + if self.sagemaker_session: + return + + if instance_type in ("local", "local_gpu"): + self.sagemaker_session = local.LocalSession(sagemaker_config=self._sagemaker_config) + else: + self.sagemaker_session = session.Session(sagemaker_config=self._sagemaker_config) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 6f846bba65..df8554f7e8 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -18,7 +18,7 @@ from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.factory.model import get_default_predictor -from sagemaker.jumpstart.session_utils import get_model_id_version_from_endpoint +from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint from sagemaker.session import Session @@ -40,9 +40,11 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + config_name: Optional[str] = None, ) -> Predictor: """Retrieves the default predictor for the model matching the given arguments. @@ -58,6 +60,8 @@ def retrieve_default( retrieve the default predictor. (Default: None). model_version (str): The version of the model for which to retrieve the default predictor. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -65,6 +69,8 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + config_name (Optional[str]): The name of the configuration to use for the + predictor. (Default: None) Returns: Predictor: The default predictor to use for the model. 
@@ -78,9 +84,9 @@ def retrieve_default( inferred_model_id, inferred_model_version, inferred_inference_component_name, - ) = get_model_id_version_from_endpoint( - endpoint_name, inference_component_name, sagemaker_session - ) + inferred_config_name, + _, + ) = get_model_info_from_endpoint(endpoint_name, inference_component_name, sagemaker_session) if not inferred_model_id: raise ValueError( @@ -92,6 +98,7 @@ def retrieve_default( model_id = inferred_model_id model_version = model_version or inferred_model_version or "*" inference_component_name = inference_component_name or inferred_inference_component_name + config_name = config_name or inferred_config_name or None else: model_version = model_version or "*" @@ -105,9 +112,11 @@ def retrieve_default( predictor=predictor, model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) diff --git a/src/sagemaker/predictor_async.py b/src/sagemaker/predictor_async.py index ef70b93599..783d034011 100644 --- a/src/sagemaker/predictor_async.py +++ b/src/sagemaker/predictor_async.py @@ -271,6 +271,7 @@ def _check_output_and_failure_paths(self, output_path, failure_path, waiter_conf output_file_found = threading.Event() failure_file_found = threading.Event() + waiter_error_caught = threading.Event() def check_output_file(): try: @@ -282,7 +283,7 @@ def check_output_file(): ) output_file_found.set() except WaiterError: - pass + waiter_error_caught.set() def check_failure_file(): try: @@ -294,7 +295,7 @@ def check_failure_file(): ) failure_file_found.set() except WaiterError: - pass + waiter_error_caught.set() output_thread = threading.Thread(target=check_output_file) failure_thread = threading.Thread(target=check_failure_file) @@ -302,7 +303,11 @@ def check_failure_file(): output_thread.start() failure_thread.start() - while not output_file_found.is_set() and not failure_file_found.is_set(): + while ( + not output_file_found.is_set() + and not failure_file_found.is_set() + and not waiter_error_caught.is_set() + ): time.sleep(1) if output_file_found.is_set(): @@ -310,17 +315,15 @@ def check_failure_file(): s3_object = self.s3_client.get_object(Bucket=output_bucket, Key=output_key) result = self.predictor._handle_response(response=s3_object) return result - failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key) - failure_response = self.predictor._handle_response(response=failure_object) + if failure_file_found.is_set(): + failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key) + failure_response = self.predictor._handle_response(response=failure_object) + raise AsyncInferenceModelError(message=failure_response) - raise ( - AsyncInferenceModelError(message=failure_response) - if failure_file_found.is_set() - else PollingTimeoutError( - message="Inference could still be running", - output_path=output_path, - seconds=waiter_config.delay * waiter_config.max_attempts, - ) + raise PollingTimeoutError( + message="Inference could still be running", + output_path=output_path, + seconds=waiter_config.delay * waiter_config.max_attempts, ) def update_endpoint( diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 7b16e3cba3..103be47caf 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -17,51 +17,51 @@ and interpretation on Amazon SageMaker.
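From the caller's perspective, the reworked waiter above produces three outcomes: a decoded result, an ``AsyncInferenceModelError`` when the failure file appears, or a ``PollingTimeoutError`` once neither file shows up. A hedged sketch (endpoint name and payload are illustrative):

```python
# Sketch of handling the three outcomes surfaced by the async waiter.
# The endpoint name and payload are placeholders.
from sagemaker.exceptions import AsyncInferenceModelError, PollingTimeoutError
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor

async_predictor = AsyncPredictor(Predictor("my-async-endpoint"))  # hypothetical endpoint
try:
    result = async_predictor.predict(data={"inputs": "example"})
except AsyncInferenceModelError as e:
    print(f"Model returned a failure response: {e}")
except PollingTimeoutError as e:
    # The inference may still be running; its output would land at e.output_path.
    print(f"Timed out while polling; check {e.output_path}")
```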
""" from __future__ import absolute_import - +import logging import os import pathlib -import logging +import re +from copy import copy from textwrap import dedent from typing import Dict, List, Optional, Union -from copy import copy import attr - from six.moves.urllib.parse import urlparse from six.moves.urllib.request import url2pathname + from sagemaker import s3 +from sagemaker.apiutils._base_types import ApiObject from sagemaker.config import ( + PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, + PROCESSING_JOB_ENVIRONMENT_PATH, + PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, PROCESSING_JOB_KMS_KEY_ID_PATH, + PROCESSING_JOB_ROLE_ARN_PATH, PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, PROCESSING_JOB_SUBNETS_PATH, - PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH, PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH, - PROCESSING_JOB_ROLE_ARN_PATH, - PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, - PROCESSING_JOB_ENVIRONMENT_PATH, ) +from sagemaker.dataset_definition.inputs import DatasetDefinition, S3Input from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.network import NetworkConfig +from sagemaker.s3 import S3Uploader +from sagemaker.session import Session from sagemaker.utils import ( + Tags, base_name_from_image, + check_and_get_run_experiment_config, + format_tags, get_config_value, name_from_base, - check_and_get_run_experiment_config, - resolve_value_from_config, resolve_class_attribute_from_config, - Tags, - format_tags, + resolve_value_from_config, ) -from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join from sagemaker.workflow.pipeline_context import runnable_by_pipeline -from sagemaker.workflow.execution_variables import ExecutionVariables -from sagemaker.workflow.entities import PipelineVariable -from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition -from sagemaker.apiutils._base_types import ApiObject -from sagemaker.s3 import S3Uploader logger = logging.getLogger(__name__) @@ -1415,7 +1415,7 @@ class RunArgs(object): class FeatureStoreOutput(ApiObject): """Configuration for processing job outputs in Amazon SageMaker Feature Store.""" - feature_group_name = None + feature_group_name: Optional[str] = None class FrameworkProcessor(ScriptProcessor): @@ -1464,7 +1464,7 @@ def __init__( instance_type (str or PipelineVariable): The type of EC2 instance to use for processing, for example, 'ml.c4.xlarge'. py_version (str): Python version you want to use for executing your - model training code. One of 'py2' or 'py3'. Defaults to 'py3'. Value + model training code. Ex `py38, py39, py310, py311`. Value is ignored when ``image_uri`` is provided. image_uri (str or PipelineVariable): The URI of the Docker image to use for the processing jobs (default: None). @@ -1658,6 +1658,7 @@ def run( # type: ignore[override] job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, kms_key: Optional[str] = None, + codeartifact_repo_arn: Optional[str] = None, ): """Runs a processing job. @@ -1758,12 +1759,21 @@ def run( # type: ignore[override] However, the value of `TrialComponentDisplayName` is honored for display in Studio. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). 
+ codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be + logged into before installing dependencies (default: None). Returns: None or pipeline step arguments in case the Processor instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ s3_runproc_sh, inputs, job_name = self._pack_and_upload_code( - code, source_dir, dependencies, git_config, job_name, inputs, kms_key + code, + source_dir, + dependencies, + git_config, + job_name, + inputs, + kms_key, + codeartifact_repo_arn, ) # Submit a processing job. @@ -1780,7 +1790,15 @@ def run( # type: ignore[override] ) def _pack_and_upload_code( - self, code, source_dir, dependencies, git_config, job_name, inputs, kms_key=None + self, + code, + source_dir, + dependencies, + git_config, + job_name, + inputs, + kms_key=None, + codeartifact_repo_arn=None, ): """Pack local code bundle and upload to Amazon S3.""" if code.startswith("s3://"): @@ -1821,12 +1839,53 @@ def _pack_and_upload_code( script = estimator.uploaded_code.script_name evaluated_kms_key = kms_key if kms_key else self.output_kms_key s3_runproc_sh = self._create_and_upload_runproc( - script, evaluated_kms_key, entrypoint_s3_uri + script, evaluated_kms_key, entrypoint_s3_uri, codeartifact_repo_arn ) return s3_runproc_sh, inputs, job_name - def _generate_framework_script(self, user_script: str) -> str: + def _get_codeartifact_command(self, codeartifact_repo_arn: str) -> str: + """Build an AWS CLI CodeArtifact command to configure pip. + + The codeartifact_repo_arn property must follow the form + # `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}` + https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html + https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies + + Args: + codeartifact_repo_arn: arn of the codeartifact repository + Returns: + codeartifact command string + """ + + arn_regex = ( + "arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)" + ":repository/(?P<domain>[^/]+)/(?P<repository>.+)" + ) + m = re.match(arn_regex, codeartifact_repo_arn) + if not m: + raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn)) + domain = m.group("domain") + owner = m.group("account") + repository = m.group("repository") + region = m.group("region") + + logger.info( + "configuring pip to use codeartifact " + "(domain: %s, domain owner: %s, repository: %s, region: %s)", + domain, + owner, + repository, + region, + ) + + return "aws codeartifact login --tool pip --domain {} --domain-owner {} --repository {} --region {}".format( # noqa: E501 pylint: disable=line-too-long + domain, owner, repository, region + ) + + def _generate_framework_script( + self, user_script: str, codeartifact_repo_arn: str = None + ) -> str: """Generate the framework entrypoint file (as text) for a processing job. This script implements the "framework" functionality for setting up your code: @@ -1837,7 +1896,16 @@ def _generate_framework_script(self, user_script: str) -> str: Args: user_script (str): Relative path to ```code``` in the source bundle - e.g. 'process.py'. + codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be + logged into before installing dependencies (default: None). """ + if codeartifact_repo_arn: + codeartifact_login_command = self._get_codeartifact_command(codeartifact_repo_arn) + else: + codeartifact_login_command = ( + "echo 'CodeArtifact repository not specified. 
Skipping login.'" + ) + return dedent( """\ #!/bin/bash @@ -1849,6 +1917,13 @@ def _generate_framework_script(self, user_script: str) -> str: set -e if [[ -f 'requirements.txt' ]]; then + # Optionally log into CodeArtifact + if ! hash aws 2>/dev/null; then + echo "AWS CLI is not installed. Skipping CodeArtifact login." + else + {codeartifact_login_command} + fi + # Some py3 containers has typing, which may breaks pip install pip uninstall --yes typing @@ -1858,6 +1933,7 @@ def _generate_framework_script(self, user_script: str) -> str: {entry_point_command} {entry_point} "$@" """ ).format( + codeartifact_login_command=codeartifact_login_command, entry_point_command=" ".join(self.command), entry_point=user_script, ) @@ -1933,7 +2009,9 @@ def _set_entrypoint(self, command, user_script_name): ) self.entrypoint = self.framework_entrypoint_command + [user_script_location] - def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): + def _create_and_upload_runproc( + self, user_script, kms_key, entrypoint_s3_uri, codeartifact_repo_arn=None + ): """Create runproc shell script and upload to S3 bucket. If leveraging a pipeline session with optimized S3 artifact paths, @@ -1949,7 +2027,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): from sagemaker.workflow.utilities import _pipeline_config, hash_object if _pipeline_config and _pipeline_config.pipeline_name: - runproc_file_str = self._generate_framework_script(user_script) + runproc_file_str = self._generate_framework_script(user_script, codeartifact_repo_arn) runproc_file_hash = hash_object(runproc_file_str) s3_uri = s3.s3_path_join( "s3://", @@ -1968,7 +2046,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): ) else: s3_runproc_sh = S3Uploader.upload_string_as_file_body( - self._generate_framework_script(user_script), + self._generate_framework_script(user_script, codeartifact_repo_arn), desired_s3_uri=entrypoint_s3_uri, kms_key=kms_key, sagemaker_session=self.sagemaker_session, diff --git a/src/sagemaker/py.typed b/src/sagemaker/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index a4e24d1ff0..208239e368 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -13,12 +13,23 @@ """Placeholder docstring""" from __future__ import absolute_import +import json import logging +import math +import os +import shutil +import tempfile +import time +from datetime import datetime from typing import Union, Optional, Dict +from urllib.request import urlretrieve +import omegaconf +from omegaconf import OmegaConf, dictconfig from packaging.version import Version from sagemaker.estimator import Framework, EstimatorBase +from sagemaker.inputs import TrainingInput, FileSystemInput from sagemaker.fw_utils import ( framework_name_from_image, framework_version_from_tag, @@ -27,15 +38,262 @@ validate_distribution, profiler_config_deprecation_warning, ) +from sagemaker.git_utils import _run_clone_command +from sagemaker.image_uris import retrieve from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig +from sagemaker.session import Session from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") +def _setup_omegaconf_resolvers(): + """Set up omegaconf 
resolvers for training recipes.""" + if not OmegaConf.has_resolver("multiply"): + OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) + if not OmegaConf.has_resolver("divide_ceil"): + OmegaConf.register_new_resolver( + "divide_ceil", lambda x, y: int(math.ceil(x / y)), replace=True + ) + if not OmegaConf.has_resolver("divide_floor"): + OmegaConf.register_new_resolver( + "divide_floor", lambda x, y: int(math.floor(x / y)), replace=True + ) + if not OmegaConf.has_resolver("add"): + OmegaConf.register_new_resolver("add", lambda *numbers: sum(numbers)) + + +def _try_resolve_recipe(recipe, key=None): + """Try to resolve recipe and return resolved recipe.""" + if key is not None: + recipe = dictconfig.DictConfig({key: recipe}) + try: + OmegaConf.resolve(recipe) + except omegaconf.errors.OmegaConfBaseException: + return None + if key is None: + return recipe + return recipe[key] + + +def _get_training_recipe_image_uri(image_cfg, region_name): + """Fetch image uri given image spec and region name to use for training.""" + if isinstance(image_cfg, str): + return image_cfg + return retrieve( + image_cfg.get("framework"), + region=region_name, + version=image_cfg.get("version"), + image_scope="training", + **image_cfg.get("additional_args"), + ) + + +def _get_training_recipe_gpu_script(code_dir, recipe, source_dir): + """Return path to training script (entry point) when running a gpu recipe.""" + model_type_to_script = { + "llama_v3": ("llama", "llama_pretrain.py"), + "mistral": ("mistral", "mistral_pretrain.py"), + "mixtral": ("mixtral", "mixtral_pretrain.py"), + "deepseek": ("deepseek", "deepseek_pretrain.py"), + "gpt_oss": ("custom_model", "custom_pretrain.py"), + } + + if "model" not in recipe: + raise ValueError("Supplied recipe does not contain required field model.") + if "model_type" not in recipe["model"]: + raise ValueError("Supplied recipe does not contain required field model_type.") + model_type = recipe["model"]["model_type"] + + for key in model_type_to_script: + if model_type.startswith(key): + model_type = key + break + + if model_type not in model_type_to_script: + raise ValueError(f"Model type {model_type} not supported") + + script_dir = os.path.join(code_dir, "examples", model_type_to_script[model_type][0]) + script = model_type_to_script[model_type][1] + shutil.copyfile(os.path.join(script_dir, script), os.path.join(source_dir, script)) + return script + + +def _get_training_recipe_trainium_script(code_dir, source_dir): + """Return path to training script (entry point) when running a trainium recipe.""" + script_dir = os.path.join(code_dir, "examples") + script = "training_orchestrator.py" + shutil.copytree(script_dir, source_dir, dirs_exist_ok=True) + return script + + +def _is_nova_recipe(recipe): + """Check if the recipe is a Nova recipe. + + A Nova recipe is identified by: + 1. Having a run section + 2. The model_type in run has a "amazon.nova" prefix + 3. The run contains model_name_or_path + + OR + + 1. Has a training_config section + 2. 
The training_config section has a distillation_data field + + Args: + recipe (OmegaConf): The loaded recipe configuration + + Returns: + bool: True if the recipe is a Nova recipe, False otherwise + """ + # Check for nova model + run_config = recipe.get("run", {}) + model_type = run_config.get("model_type", "").lower() + has_nova_model = ( + model_type and "amazon.nova" in model_type and "model_name_or_path" in run_config + ) + + # Check for distillation data + training_config = recipe.get("training_config", {}) + has_distillation = training_config.get("distillation_data") is not None + + return bool(has_nova_model) or bool(has_distillation) + + +def _recipe_initialize_args(source_dir): + """Initialize the arguments dictionary for recipe setup. + + Args: + source_dir (str): Path to the source directory. + + Returns: + dict: Initialized arguments dictionary. + + Raises: + ValueError: If source_dir is not a local directory. + """ + args = {"hyperparameters": {}} + + if source_dir is None: + args["source_dir"] = "." + else: + if not os.path.exists(source_dir): + raise ValueError("When using training_recipe, source_dir must be a local directory.") + args["source_dir"] = source_dir + + return args + + +def _recipe_get_region_name(kwargs): + """Get the AWS region name from session or create a new session. + + Args: + kwargs (dict): Dictionary of keyword arguments. + + Returns: + str: AWS region name. + """ + if kwargs.get("sagemaker_session") is not None: + return kwargs.get("sagemaker_session").boto_region_name + return Session().boto_region_name + + +def _recipe_load_config(): + """Load the training recipes configuration from JSON file. + + Returns: + dict: Training recipes configuration. + """ + training_recipes_cfg_filename = os.path.join(os.path.dirname(__file__), "training_recipes.json") + with open(training_recipes_cfg_filename) as training_recipes_cfg_file: + return json.load(training_recipes_cfg_file) + + +def _recipe_load_from_yaml(training_recipe, temp_local_recipe): + """Load recipe from a YAML file or URL. + + Args: + training_recipe (str): Path to the training recipe. + temp_local_recipe (str): Path to the temporary local recipe file. + + Raises: + ValueError: If the recipe cannot be fetched. + """ + if os.path.isfile(training_recipe): + shutil.copy(training_recipe, temp_local_recipe) + else: + try: + urlretrieve(training_recipe, temp_local_recipe) + except Exception as e: + raise ValueError( + f"Could not fetch the provided recipe {training_recipe}: exception {str(e)}" + ) + + +def _recipe_load_predefined( + training_recipe, recipe_launcher_dir, temp_local_recipe, training_recipes_cfg +): + """Load a predefined recipe from the recipe launcher. + + Args: + training_recipe (str): Name of the predefined recipe. + recipe_launcher_dir (str): Path to the recipe launcher directory. + temp_local_recipe (str): Path to the temporary local recipe file. + training_recipes_cfg (dict): Training recipes configuration. + + Raises: + ValueError: If the recipe cannot be found.
+ """ + launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get( + "launcher_repo" + ) + _run_clone_command(launcher_repo, recipe_launcher_dir) + recipe_path = os.path.join( + recipe_launcher_dir, + "recipes_collection", + "recipes", + training_recipe + ".yaml", + ) + if os.path.isfile(recipe_path): + shutil.copy(recipe_path, temp_local_recipe) + else: + raise ValueError(f"Recipe {training_recipe} not found.") + + +def _device_get_distribution(device_type): + """Get the distribution configuration based on device type. + + Args: + device_type (str): Device type (gpu, trainium, or cpu). + + Returns: + dict: Distribution configuration. + + Raises: + ValueError: If the device type is not supported. + """ + if device_type == "gpu": + smp_options = { + "enabled": True, + "parameters": { + "placement_strategy": "cluster", + }, + } + return { + "smdistributed": {"modelparallel": smp_options}, + "torch_distributed": {"enabled": True}, + } + elif device_type == "trainium": + return { + "torch_distributed": {"enabled": True}, + } + else: + return {} + + class PyTorch(Framework): """Handle end-to-end training and deployment of custom PyTorch code.""" @@ -44,9 +302,12 @@ class PyTorch(Framework): LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled" INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type" + # [TODO] Add image uris to image_uri_config/_.json and use image_uris.retrieve + # to retrieve the image uri below before GA. + def __init__( self, - entry_point: Union[str, PipelineVariable], + entry_point: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, py_version: Optional[str] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, @@ -54,6 +315,8 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, distribution: Optional[Dict] = None, compiler_config: Optional[TrainingCompilerConfig] = None, + training_recipe: Optional[str] = None, + recipe_overrides: Optional[Dict] = None, **kwargs, ): """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment. @@ -87,9 +350,9 @@ def __init__( unless ``image_uri`` is provided. source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry - point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point file (default: None). If ``source_dir`` is an S3 URI, it must point to a + file with name ``sourcedir.tar.gz``. Structure within this directory are preserved + when training on Amazon SageMaker. Must be a local path when using training_recipe. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on @@ -246,6 +509,14 @@ def __init__( compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`): Configures SageMaker Training Compiler to accelerate training. + training_recipe (str): Training recipe to use. This is a local file path, a url, + or a recipe provided by Amazon SageMaker HyperPod recipes, + such as training/llama/hf_llama3_70b_seq8k_gpu_p5x64_pretrain. + This is required when using recipes. + recipe_overrides (Dict): Dictionary specifying key values to override in the + training_recipe. 
This is optional when using + Amazon SageMaker HyperPod recipes. + **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor. @@ -255,6 +526,31 @@ def __init__( :class:`~sagemaker.estimator.Framework` and :class:`~sagemaker.estimator.EstimatorBase`. """ + self.is_nova_recipe = False + if training_recipe is not None: + if entry_point is not None: + logger.warning("Argument entry_point will be ignored with training_recipe.") + if hyperparameters is not None: + logger.warning("Argument hyperparameters will be ignored with training recipe.") + if distribution is not None: + logger.warning("Argument distribution will be ignored with training_recipe.") + args = self._setup_for_training_recipe( + training_recipe, recipe_overrides, source_dir, kwargs + ) + + if self.is_nova_recipe and image_uri is None: + raise ValueError("Must supply image_uri for nova jobs.") + + entry_point = args["entry_point"] + source_dir = args["source_dir"] + hyperparameters = args["hyperparameters"] + if image_uri is None: + image_uri = args["default_image_uri"] + distribution = args["distribution"] + elif entry_point is None: + raise ValueError( + "Argument entry_point must be set when training_recipe is not provided" + ) validate_version_or_image_args(framework_version, py_version, image_uri) if py_version == "py2": logger.warning( @@ -269,13 +565,32 @@ def __init__( kwargs["enable_sagemaker_metrics"] = True super(PyTorch, self).__init__( - entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs + entry_point, + source_dir, + hyperparameters, + image_uri=image_uri, + is_nova_job=self.is_nova_recipe, + **kwargs, ) if "entry_point" not in kwargs: kwargs["entry_point"] = entry_point if distribution is not None: + # rewrite pytorchddp to smdistributed + if "pytorchddp" in distribution: + if "smdistributed" in distribution: + raise ValueError( + "Cannot use both pytorchddp and smdistributed " + "distribution options together.", + distribution, + ) + + # convert pytorchddp distribution into smdistributed distribution + distribution = distribution.copy() + distribution["smdistributed"] = {"dataparallel": distribution["pytorchddp"]} + del distribution["pytorchddp"] + distribution = validate_distribution( distribution, self.instance_groups, @@ -362,6 +677,72 @@ def hyperparameters(self): return hyperparameters + def fit( + self, + inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, + wait: bool = True, + logs: str = "All", + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + ): + """Train a model using the input training dataset. + + Adds the recipe file to the inputs when a training recipe is used. + + Args: + inputs (str or dict or sagemaker.inputs.TrainingInput or + sagemaker.inputs.FileSystemInput): Information about the training data. + wait (bool): Whether the call should wait until the job completes (default: True). + logs ([str]): A list of strings specifying which logs to print. + job_name (str): Training job name. + experiment_config (dict[str, str]): Experiment management configuration. 
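The `pytorchddp` rewrite above is a pure dictionary transformation; a standalone sketch of what it does to a user-supplied distribution config:

```python
# A pytorchddp request is rewritten into the equivalent smdistributed
# dataparallel request, carrying the same options across.
distribution = {"pytorchddp": {"enabled": True}}

distribution = distribution.copy()
distribution["smdistributed"] = {"dataparallel": distribution["pytorchddp"]}
del distribution["pytorchddp"]

assert distribution == {"smdistributed": {"dataparallel": {"enabled": True}}}
```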
+ + Returns: + None or pipeline step arguments + """ + # Handle recipe upload and input channel creation if we have a recipe + if ( + self.is_nova_recipe is not None + and self.is_nova_recipe + and hasattr(self, "training_recipe_file") + and self.training_recipe_file + ): + # Upload the recipe to S3 if it hasn't been uploaded yet + if not hasattr(self, "recipe_s3_uri") or not self.recipe_s3_uri: + self.recipe_s3_uri = self._upload_recipe_to_s3( + self.sagemaker_session, self.training_recipe_file.name + ) + + # Prepare inputs dictionary + from sagemaker.inputs import TrainingInput + + if inputs is None: + inputs = {} + elif not isinstance(inputs, dict): + inputs = {"training": inputs} + + # Add the recipe channel + recipe_channel_name = "recipe" + inputs[recipe_channel_name] = TrainingInput( + s3_data=os.path.dirname(self.recipe_s3_uri), input_mode="File" + ) + + # Update hyperparameters to reference the recipe location in the container + recipe_filename = os.path.basename(self.training_recipe_file.name) + + self._hyperparameters.update( + { + "sagemaker_recipe_local_path": f"/opt/ml/input/data/{recipe_channel_name}/{recipe_filename}", + } + ) + return super(PyTorch, self).fit( + inputs=inputs, + wait=wait, + logs=logs, + job_name=job_name, + experiment_config=experiment_config, + ) + def create_model( self, model_server_workers=None, @@ -466,3 +847,469 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na ) return init_params + + # The old class methods have been replaced by static methods and module-level functions + + @staticmethod + def _recipe_load(training_recipe, recipe_launcher_dir, training_recipes_cfg): + """Load the recipe from file path, URL, or predefined recipe. + + Args: + training_recipe (str): Path to the training recipe. + recipe_launcher_dir (str): Path to the recipe launcher directory. + training_recipes_cfg (dict): Training recipes configuration. + + Returns: + tuple: Recipe name and loaded recipe. + + Raises: + ValueError: If the recipe cannot be fetched or found. + """ + recipe_name = os.path.splitext(os.path.basename(training_recipe))[0] + temp_local_recipe = tempfile.NamedTemporaryFile(prefix=recipe_name, suffix=".yaml").name + + try: + if training_recipe.endswith(".yaml"): + _recipe_load_from_yaml(training_recipe, temp_local_recipe) + else: + _recipe_load_predefined( + training_recipe, recipe_launcher_dir, temp_local_recipe, training_recipes_cfg + ) + + recipe = OmegaConf.load(temp_local_recipe) + os.unlink(temp_local_recipe) + return recipe_name, recipe + except Exception as e: + if os.path.exists(temp_local_recipe): + os.unlink(temp_local_recipe) + raise e + + @staticmethod + def _device_get_image_uri(args, device_type, recipe_config, region_name, recipe): + """Get the appropriate image URI based on device type. + + Args: + args (dict): Arguments dictionary. + device_type (str): Device type (gpu, trainium, or cpu). + recipe_config (dict): Training recipes configuration. + region_name (str): AWS region name. + recipe (OmegaConf): Recipe configuration. + + Returns: + str: Image URI or None if no image URI was found. 
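To make the recipe-channel wiring in `fit` concrete, a sketch with hypothetical S3 values showing how the channel location and the in-container path are derived (the uploaded object keeps the local recipe file name, so the basename of the URI matches it):

```python
import os

recipe_s3_uri = "s3://my-bucket/recipes/2025-01-01_my-recipe/my-recipe_ab12.yaml"  # hypothetical
recipe_channel_name = "recipe"

channel_s3_data = os.path.dirname(recipe_s3_uri)  # the channel mounts the enclosing prefix
recipe_filename = os.path.basename(recipe_s3_uri)
local_path = f"/opt/ml/input/data/{recipe_channel_name}/{recipe_filename}"

print(channel_s3_data)  # s3://my-bucket/recipes/2025-01-01_my-recipe
print(local_path)       # /opt/ml/input/data/recipe/my-recipe_ab12.yaml
```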
+ """ + if "default_image_uri" in args: + logger.debug("Image URI already exists") + return args["default_image_uri"] + elif device_type == "gpu": + logger.info("Using GPU training image") + return _get_training_recipe_image_uri(recipe_config.get("gpu_image"), region_name) + elif device_type == "trainium": + logger.info("Using Trainium training image") + return _get_training_recipe_image_uri(recipe_config.get("neuron_image"), region_name) + else: + return None + + @staticmethod + def _recipe_setup_nova(args, recipe): + """Set up configuration for Nova recipes. + + Args: + args (dict): Arguments dictionary. + recipe (OmegaConf): Recipe configuration. + kwargs (dict): Dictionary of keyword arguments. + """ + run_config = recipe.get("run", {}) + model_name_or_path = run_config.get("model_name_or_path") + + # Set hyperparameters based on model_name_or_path + if model_name_or_path: + if model_name_or_path.startswith("s3://"): + args["hyperparameters"]["base_model_location"] = model_name_or_path + else: + args["hyperparameters"]["base_model"] = model_name_or_path + + args["entry_point"] = None + args["source_dir"] = None + + @staticmethod + def _device_validate_and_get_type(kwargs, recipe): + """Validate instance type and determine device type. + + Args: + kwargs (dict): Dictionary of keyword arguments. + recipe (OmegaConf): Recipe configuration. + + Returns: + str: Device type (gpu, trainium, or cpu). + + Raises: + ValueError: If instance_type is not provided or recipe is invalid. + """ + if "instance_type" not in kwargs: + raise ValueError("Must pass instance type to estimator when using training recipes.") + + if not _is_nova_recipe(recipe) and "trainer" not in recipe: + raise ValueError("Supplied recipe does not contain required field trainer.") + + instance_type = kwargs["instance_type"].split(".")[1] + if instance_type.startswith(("p", "g")): + return "gpu" + elif instance_type.startswith("trn"): + return "trainium" + else: + return "cpu" + + @staticmethod + def _device_handle_instance_count(kwargs, recipe): + """Handle instance count configuration. + + Args: + kwargs (dict): Dictionary of keyword arguments. + recipe (OmegaConf): Recipe configuration. + + Raises: + ValueError: If instance_count is not provided and cannot be found in the recipe. + """ + # Check if instance_count is already provided in kwargs + + is_nova = _is_nova_recipe(recipe) + if "instance_count" in kwargs: + # Warn if there are conflicting configurations in the recipe + if "num_nodes" in recipe.get("trainer", {}): + logger.warning( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring trainer -> num_nodes in recipe." + ) + if is_nova and "replicas" in recipe.get("run", {}): + logger.warning( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring run -> replicas in recipe." + ) + return + + # Try to get instance_count from recipe + if "trainer" in recipe and "num_nodes" in recipe["trainer"]: + kwargs["instance_count"] = recipe["trainer"]["num_nodes"] + return + + if is_nova and "run" in recipe and "replicas" in recipe["run"]: + kwargs["instance_count"] = recipe["run"]["replicas"] + return + + # If we get here, we couldn't find instance_count anywhere + raise ValueError( + "Must set either instance_count argument for estimator or " + "set trainer -> num_nodes or run -> replicas in recipe for nova jobs." 
+ ) + + @staticmethod + def _device_get_entry_point_script( + device_type, recipe_train_dir, recipe, source_dir, training_recipes_cfg + ): + """Get the entry point script based on device type. + + Args: + device_type (str): Device type (gpu, trainium, or cpu). + recipe_train_dir (str): Path to the recipe training directory. + recipe (OmegaConf): Recipe configuration. + source_dir (str): Path to the source directory. + training_recipes_cfg (dict): Training recipes configuration. + + Returns: + str: Path to the entry point script or None if not applicable. + """ + if device_type == "gpu": + adapter_repo = os.environ.get("TRAINING_ADAPTER_GIT", None) or training_recipes_cfg.get( + "adapter_repo" + ) + _run_clone_command(adapter_repo, recipe_train_dir) + return _get_training_recipe_gpu_script(recipe_train_dir, recipe, source_dir) + elif device_type == "trainium": + _run_clone_command(training_recipes_cfg.get("neuron_dist_repo"), recipe_train_dir) + return _get_training_recipe_trainium_script(recipe_train_dir, source_dir) + elif device_type == "cpu": + raise ValueError( + f"Devices of type {device_type} are not supported with training recipes." + ) + return None + + def _recipe_resolve_and_save(self, recipe, recipe_name, source_dir): + """Resolve and save the final recipe configuration. + + Args: + recipe (OmegaConf): Recipe configuration. + recipe_name (str): Recipe name. + source_dir (str): Path to the source directory. + + Returns: + OmegaConf: Resolved recipe configuration. + + Raises: + RuntimeError: If the recipe cannot be resolved. + """ + _setup_omegaconf_resolvers() + + # Try different resolution strategies + final_recipe = _try_resolve_recipe(recipe) + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "recipes") + if final_recipe is None: + final_recipe = _try_resolve_recipe(recipe, "training") + if final_recipe is None: + raise RuntimeError("Could not resolve provided recipe.") + + # Save the resolved recipe - this sets an instance attribute + self.training_recipe_file = tempfile.NamedTemporaryFile( + dir=source_dir, + prefix=recipe_name + "_", + suffix=".yaml", + ) + OmegaConf.save(config=final_recipe, f=self.training_recipe_file.name) + + return final_recipe + + def _upload_recipe_to_s3(self, session, recipe_file_path): + """Upload the recipe file to S3. + + Args: + session (sagemaker.session.Session): SageMaker session. + recipe_file_path (str): Path to the recipe file. + + Returns: + str: S3 URI of the uploaded recipe file. + """ + bucket = session.default_bucket() + key_prefix = session.default_bucket_prefix + + recipe_filename = os.path.basename(recipe_file_path) + + readable_date = datetime.fromtimestamp(int(time.time())) + date_format = readable_date.strftime("%Y-%m-%d") + + if key_prefix != "None" and key_prefix is not None: + s3_key = f"{key_prefix}/recipes/{date_format}_{recipe_filename[:-5]}" + else: + s3_key = f"recipes/{date_format}_{recipe_filename[:-5]}" + + # Upload the recipe file to S3 + s3_uri = session.upload_data( + path=recipe_file_path, + bucket=bucket, + key_prefix=os.path.dirname(os.path.join(s3_key, recipe_filename)), + ) + + # Return the full S3 URI to the recipe file + return f"{s3_uri}" + + def _setup_for_training_recipe(self, training_recipe, recipe_overrides, source_dir, kwargs): + """Performs training recipe specific setup and returns recipe specific args. + + Updates kwargs and returns a dictionary of args to use for estimator + initialization and setup when using a training recipe. 
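The S3 layout produced by `_upload_recipe_to_s3` can be previewed in isolation; the bucket prefix and file name below are hypothetical:

```python
import os
from datetime import datetime

recipe_file_path = "/tmp/llama_recipe_x1y2.yaml"  # hypothetical resolved-recipe temp file
key_prefix = "team-prefix"                        # session.default_bucket_prefix, when set

recipe_filename = os.path.basename(recipe_file_path)
date_format = datetime.now().strftime("%Y-%m-%d")
s3_key = f"{key_prefix}/recipes/{date_format}_{recipe_filename[:-5]}"  # [:-5] strips ".yaml"

# upload_data receives the directory portion as key_prefix, so the object lands at
# s3://<default-bucket>/<s3_key>/<recipe_filename>
print(os.path.dirname(os.path.join(s3_key, recipe_filename)))
```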
+ + Args: + training_recipe (str): A recipe which is a local file path, a url or a + sagemaker training recipe. + recipe_overrides (Dict): Dictionary specifying key values to override in the + training recipe. + source_dir (str): Path (absolute, or relative) to a directory where to copy + the scripts for training recipe. + kwargs (dict): Dictionary of args used for estimator initialization. + + Returns: + dict containing arg values for estimator initialization and setup. + """ + region_name = _recipe_get_region_name(kwargs) + training_recipes_cfg = _recipe_load_config() + recipe_overrides = recipe_overrides or {} + + # Create temporary directories for recipe processing + with ( + tempfile.TemporaryDirectory(prefix="training_") as recipe_train_dir, + tempfile.TemporaryDirectory(prefix="launcher_") as recipe_launcher_dir, + ): + # Load and process the recipe + recipe_name, recipe = PyTorch._recipe_load( + training_recipe, recipe_launcher_dir, training_recipes_cfg + ) + + # Merge with overrides + recipe = OmegaConf.merge(recipe, recipe_overrides) + + self.is_nova_recipe = _is_nova_recipe(recipe) + if self.is_nova_recipe: + return self._setup_for_nova_recipe( + recipe, + recipe_name, + source_dir, + kwargs, + ) + else: + return self._setup_for_standard_recipe( + recipe, + recipe_name, + source_dir, + kwargs, + recipe_train_dir, + training_recipes_cfg, + region_name, + ) + + def _setup_for_nova_recipe( + self, + recipe, + recipe_name, + source_dir, + kwargs, + ): + """Set up configuration specifically for Nova recipes. + + Args: + recipe (OmegaConf): Recipe configuration. + recipe_name (str): Recipe name. + source_dir (str): Path to the source directory. + kwargs (dict): Dictionary of keyword arguments. + + Returns: + dict: Arguments dictionary for estimator initialization. + """ + # Initialize args + args = _recipe_initialize_args(source_dir) + + # Set up Nova-specific configuration + run_config = recipe.get("run", {}) + model_name_or_path = run_config.get("model_name_or_path") + + # Set hyperparameters based on model_name_or_path + if model_name_or_path: + if model_name_or_path.startswith("s3://"): + args["hyperparameters"]["base_model_location"] = model_name_or_path + else: + args["hyperparameters"]["base_model"] = model_name_or_path + + args["entry_point"] = None + args["source_dir"] = None + args["distribution"] = {} + + logger.info("Remote debugging, profiler and debugger hooks are disabled for Nova recipes.") + kwargs["enable_remote_debug"] = False + kwargs["disable_profiler"] = True + kwargs["debugger_hook_config"] = False + + # Handle instance count for Nova recipes + if "instance_count" in kwargs: + if "replicas" in recipe.get("run", {}): + logger.warning( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring run -> replicas in recipe." + ) + elif "run" in recipe and "replicas" in recipe["run"]: + kwargs["instance_count"] = recipe["run"]["replicas"] + else: + raise ValueError( + "Must set either instance_count argument for estimator or " + "set run -> replicas in recipe for nova jobs." 
+            )
+
+        training_config = recipe.get("training_config", {})
+        distillation_data = training_config.get("distillation_data", {})
+        if bool(distillation_data):
+            args["hyperparameters"]["distillation_data"] = distillation_data
+            args["hyperparameters"]["role_arn"] = kwargs["role"]
+            kms_key = training_config.get("kms_key")
+            if kms_key is None:
+                raise ValueError(
+                    'Nova distillation job recipe requires "kms_key" field in "training_config"'
+                )
+            args["hyperparameters"]["kms_key"] = kms_key
+
+        # Resolve and save the final recipe
+        self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"])
+
+        return args
+
+    def _setup_for_standard_recipe(
+        self,
+        recipe,
+        recipe_name,
+        source_dir,
+        kwargs,
+        recipe_train_dir,
+        training_recipes_cfg,
+        region_name,
+    ):
+        """Set up configuration for standard (non-Nova) recipes.
+
+        Args:
+            recipe (OmegaConf): Recipe configuration.
+            recipe_name (str): Recipe name.
+            source_dir (str): Path to the source directory.
+            kwargs (dict): Dictionary of keyword arguments.
+            recipe_train_dir (str): Path to the recipe training directory.
+            training_recipes_cfg (dict): Training recipes configuration.
+            region_name (str): AWS region name.
+
+        Returns:
+            dict: Arguments dictionary for estimator initialization.
+        """
+        # Initialize args
+        args = _recipe_initialize_args(source_dir)
+
+        # Validate recipe structure
+        if "trainer" not in recipe:
+            raise ValueError("Supplied recipe does not contain required field trainer.")
+
+        # Handle instance count for standard recipes
+        if "instance_count" in kwargs:
+            if "num_nodes" in recipe.get("trainer", {}):
+                logger.warning(
+                    "Using instance_count argument to estimator to set number "
+                    "of nodes. Ignoring trainer -> num_nodes in recipe."
+                )
+        elif "trainer" in recipe and "num_nodes" in recipe["trainer"]:
+            kwargs["instance_count"] = recipe["trainer"]["num_nodes"]
+        else:
+            raise ValueError(
+                "Must set either instance_count argument for estimator or "
+                "set trainer -> num_nodes in recipe."
+            )
+
+        # Determine device type
+        device_type = PyTorch._device_validate_and_get_type(kwargs, recipe)
+
+        # Get image URI
+        image_uri = PyTorch._device_get_image_uri(
+            args, device_type, training_recipes_cfg, region_name, recipe
+        )
+        args["default_image_uri"] = image_uri if image_uri is not None else ""
+
+        # Setup device-specific configuration
+        args["distribution"] = _device_get_distribution(device_type)
+
+        # Set entry point if not already set
+        if "entry_point" not in args:
+            script = PyTorch._device_get_entry_point_script(
+                device_type, recipe_train_dir, recipe, args["source_dir"], training_recipes_cfg
+            )
+            if script:
+                args["entry_point"] = os.path.basename(script)
+
+        # Handle container configuration
+        if "container" in recipe and not recipe["container"]:
+            logger.warning(
+                "Ignoring container from training_recipe. Use image_uri arg for estimator."
+ ) + + # Resolve and save the final recipe + self._recipe_resolve_and_save(recipe, recipe_name, args["source_dir"]) + + # Update hyperparameters with recipe configuration + args["hyperparameters"].update( + { + "config-path": ".", + "config-name": os.path.basename(self.training_recipe_file.name), + } + ) + + return args diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index f490e49375..958327ba08 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import packaging.version @@ -29,12 +29,17 @@ ) from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.pytorch import defaults from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer from sagemaker.utils import to_string from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger("sagemaker") @@ -94,9 +99,9 @@ def __init__( framework_version: str = "1.3", py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = PyTorchPredictor, + predictor_cls: Optional[Callable] = PyTorchPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, - **kwargs + **kwargs, ): """Initialize a PyTorchModel. @@ -123,7 +128,7 @@ def __init__( If ``framework_version`` or ``py_version`` are ``None``, then ``image_uri`` is required. If ``image_uri`` is also ``None``, then a ``ValueError`` will be raised. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -179,6 +184,8 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -230,6 +237,9 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
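Looping back to the standard-recipe setup that ends above: the two hyperparameters it adds are what a Hydra-style training script reads to locate the resolved recipe saved next to the entry point. A hypothetical final value (the file name is generated by `_recipe_resolve_and_save` via `NamedTemporaryFile`, so the suffix here is illustrative only):

```python
hyperparameters = {
    "config-path": ".",
    "config-name": "hf_llama3_70b_seq8k_a1b2.yaml",
}
```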
@@ -270,6 +280,8 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + model_life_cycle=model_life_cycle, ) def prepare_container_def( @@ -278,6 +290,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """A container definition with framework configuration set in model environment variables. @@ -329,6 +342,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/pytorch/training_recipes.json b/src/sagemaker/pytorch/training_recipes.json new file mode 100644 index 0000000000..5aeccce5a1 --- /dev/null +++ b/src/sagemaker/pytorch/training_recipes.json @@ -0,0 +1,17 @@ +{ + "adapter_repo": "https://github.com/aws/sagemaker-training-adapter-for-nemo.git", + "launcher_repo": "https://github.com/aws/sagemaker-hyperpod-recipes.git", + "neuron_dist_repo": "https://github.com/aws-neuron/neuronx-distributed-training.git", + "gpu_image" : { + "framework": "pytorch-smp", + "version": "2.4.1", + "additional_args": { + "container_version": "cu121" + } + }, + "neuron_image" : { + "framework": "hyperpod-recipes-neuron", + "version": "2.1.2", + "additional_args": {} + } +} diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py index 0dc69d8647..55b4654aa9 100644 --- a/src/sagemaker/remote_function/client.py +++ b/src/sagemaker/remote_function/client.py @@ -40,6 +40,8 @@ from sagemaker.utils import name_from_base, base_from_name from sagemaker.remote_function.spark_config import SparkConfig from sagemaker.remote_function.custom_file_filter import CustomFileFilter +from sagemaker.telemetry.telemetry_logging import _telemetry_emitter +from sagemaker.telemetry.constants import Feature _API_CALL_LIMIT = { "SubmittingIntervalInSecs": 1, @@ -57,6 +59,7 @@ logger = logging_config.get_logger() +@_telemetry_emitter(feature=Feature.REMOTE_FUNCTION, func_name="remote_function.remote") def remote( _func=None, *, @@ -87,6 +90,10 @@ def remote( spark_config: SparkConfig = None, use_spot_instances=False, max_wait_time_in_seconds=None, + disable_output_compression: bool = False, + use_torchrun: bool = False, + use_mpirun: bool = False, + nproc_per_node: Optional[int] = None, ): """Decorator for running the annotated function as a SageMaker training job. @@ -202,7 +209,8 @@ def remote( files are accepted and uploaded to S3. instance_count (int): The number of instances to use. Defaults to 1. - NOTE: Remote function does not support instance_count > 1 for non Spark jobs. + NOTE: Remote function supports instance_count > 1 for Spark jobs, torchrun and + mpirun utilities instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown. @@ -275,6 +283,19 @@ def remote( max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job. After this amount of time Amazon SageMaker will stop waiting for managed spot training job to complete. Defaults to ``None``. + + disable_output_compression (bool): Optional. When set to true, Model is uploaded to + Amazon S3 without compression after training finishes. + + use_torchrun (bool): Specifies whether to use torchrun for distributed training. + Defaults to ``False``. 
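A hypothetical invocation showing the new remote-function flags together; the instance size and function body are placeholders, and `nproc_per_node` is left unset so it is inferred from the instance type:

```python
from sagemaker.remote_function import remote

@remote(
    instance_type="ml.g5.12xlarge",
    instance_count=2,
    use_torchrun=True,
    disable_output_compression=True,
)
def train():
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    ...  # training loop
```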
+
+        use_mpirun (bool): Specifies whether to use mpirun for distributed training.
+            Defaults to ``False``.
+
+        nproc_per_node (int): Optional. Specifies the number of processes per node for
+            distributed training. Defaults to ``None``.
+            If not set, this is configured automatically based on the instance type.
    """

    def _remote(func):
@@ -307,14 +328,23 @@ def _remote(func):
            spark_config=spark_config,
            use_spot_instances=use_spot_instances,
            max_wait_time_in_seconds=max_wait_time_in_seconds,
+            disable_output_compression=disable_output_compression,
+            use_torchrun=use_torchrun,
+            use_mpirun=use_mpirun,
+            nproc_per_node=nproc_per_node,
        )

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
-            if instance_count > 1 and not spark_config:
+            if instance_count > 1 and not (
+                (spark_config is not None and not use_torchrun and not use_mpirun)
+                or (spark_config is None and use_torchrun and not use_mpirun)
+                or (spark_config is None and not use_torchrun and use_mpirun)
+            ):
                raise ValueError(
-                    "Remote function do not support training on multi instances. "
+                    "Remote function does not support training on multiple instances "
+                    + "without spark_config or use_torchrun or use_mpirun. "
                    + "Please provide instance_count = 1"
                )

@@ -518,6 +548,10 @@ def __init__(
        spark_config: SparkConfig = None,
        use_spot_instances=False,
        max_wait_time_in_seconds=None,
+        disable_output_compression: bool = False,
+        use_torchrun: bool = False,
+        use_mpirun: bool = False,
+        nproc_per_node: Optional[int] = None,
    ):
        """Constructor for RemoteExecutor

@@ -630,7 +664,8 @@ def __init__(
                files are accepted and uploaded to S3.

            instance_count (int): The number of instances to use. Defaults to 1.
-                NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
+                NOTE: Remote function supports instance_count > 1 for Spark jobs and for the
+                    torchrun and mpirun utilities.

            instance_type (str): The Amazon Elastic Compute Cloud (EC2) instance type to use to run
                the SageMaker job. e.g. ml.c4.xlarge. If not provided, a ValueError is thrown.
@@ -706,15 +741,33 @@ def __init__(
            max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                After this amount of time Amazon SageMaker will stop waiting for managed spot
                training job to complete. Defaults to ``None``.
+
+            disable_output_compression (bool): Optional. When set to true, the model is uploaded
+                to Amazon S3 without compression after training finishes.
+
+            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+                Defaults to ``False``.
+
+            use_mpirun (bool): Specifies whether to use mpirun for distributed training.
+                Defaults to ``False``.
+
+            nproc_per_node (int): Optional. Specifies the number of processes per node for
+                distributed training. Defaults to ``None``.
+                If not set, this is configured automatically based on the instance type.
        """
        self.max_parallel_jobs = max_parallel_jobs

        if self.max_parallel_jobs <= 0:
            raise ValueError("max_parallel_jobs must be greater than 0.")

-        if instance_count > 1 and not spark_config:
+        if instance_count > 1 and not (
+            (spark_config is not None and not use_torchrun and not use_mpirun)
+            or (spark_config is None and use_torchrun and not use_mpirun)
+            or (spark_config is None and not use_torchrun and use_mpirun)
+        ):
            raise ValueError(
-                "Remote function do not support training on multi instances. "
+                "Remote function does not support training on multiple instances "
+                + "without spark_config or use_torchrun or use_mpirun. 
" + "Please provide instance_count = 1" ) @@ -746,6 +799,10 @@ def __init__( spark_config=spark_config, use_spot_instances=use_spot_instances, max_wait_time_in_seconds=max_wait_time_in_seconds, + disable_output_compression=disable_output_compression, + use_torchrun=use_torchrun, + use_mpirun=use_mpirun, + nproc_per_node=nproc_per_node, ) self._state_condition = threading.Condition() diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 5814ee45ff..9000ccda08 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -81,6 +81,7 @@ # runtime script names BOOTSTRAP_SCRIPT_NAME = "bootstrap_runtime_environment.py" +MPI_UTILS_SCRIPT_NAME = "mpi_utils_remote.py" ENTRYPOINT_SCRIPT_NAME = "job_driver.sh" PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" RUNTIME_MANAGER_SCRIPT_NAME = "runtime_environment_manager.py" @@ -130,9 +131,12 @@ export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json printf "INFO: Bootstraping runtime environment.\\n" python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] then @@ -155,13 +159,166 @@ fi printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n" + printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function \\n" $conda_exe run -n $conda_env python -m sagemaker.remote_function.invoke_function "$@" else printf "INFO: No conda env provided. Invoking remote function\\n" + printf "INFO: python -m sagemaker.remote_function.invoke_function \\n" python -m sagemaker.remote_function.invoke_function "$@" fi """ +ENTRYPOINT_MPIRUN_SCRIPT = f""" +#!/bin/bash + +# Entry point for bootstrapping runtime environment and invoking remote function with mpirun + +set -eu + +PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} +export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs +printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" +export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip +printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" + +printf "INFO: /opt/ml/input/config/resourceconfig.json:\\n" +cat /opt/ml/input/config/resourceconfig.json + +printf "INFO: Bootstraping runtime environment.\\n" +python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@" +source /opt/ml/input/sm_training.env + +if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ] +then + if [ -f "remote_function_conda_env.txt" ] + then + cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt + fi + printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n" + cd {JOB_REMOTE_FUNCTION_WORKSPACE} +fi + +if [ -f "remote_function_conda_env.txt" ] +then + conda_env=$(cat remote_function_conda_env.txt) + + if which mamba >/dev/null; then + conda_exe="mamba" + else + conda_exe="conda" + fi + + if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + + printf "INFO: Invoking remote function with mpirun inside conda environment: $conda_env.\\n" + printf "INFO: $conda_exe run -n $conda_env mpirun --host 
$SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + + python -m mpi4py -m sagemaker.remote_function.invoke_function \\n" + $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function "$@" + + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 + else + printf "INFO: This is the instance $SM_CURRENT_HOST. mpirun command terminated\\n" + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + fi +else + if [ "$SM_CURRENT_HOST" = "$SM_MASTER_ADDR" ]; then + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + + printf "INFO: No conda env provided. Invoking remote function with mpirun\\n" + printf "INFO: mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function \\n" + + mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ + --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ + -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ + -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ + -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ + $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ + python -m mpi4py -m sagemaker.remote_function.invoke_function "$@" + + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 + else + printf "INFO: This is the instance $SM_CURRENT_HOST.\\n" + python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} + fi +fi +""" + +ENTRYPOINT_TORCHRUN_SCRIPT = f""" +#!/bin/bash + +# Entry point for bootstrapping runtime environment and invoking remote function with torchrun + +set -eu + +PERSISTENT_CACHE_DIR=${{SAGEMAKER_MANAGED_WARMPOOL_CACHE_DIRECTORY:-/opt/ml/cache}} +export CONDA_PKGS_DIRS=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/conda/pkgs +printf "INFO: CONDA_PKGS_DIRS is set to '$CONDA_PKGS_DIRS'\\n" +export PIP_CACHE_DIR=${{PERSISTENT_CACHE_DIR}}/sm_remotefunction_user_dependencies_cache/pip +printf "INFO: PIP_CACHE_DIR is set to '$PIP_CACHE_DIR'\\n" + +printf "INFO: 
/opt/ml/input/config/resourceconfig.json:\\n"
+cat /opt/ml/input/config/resourceconfig.json
+
+printf "INFO: Bootstraping runtime environment.\\n"
+python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{BOOTSTRAP_SCRIPT_NAME} "$@"
+source /opt/ml/input/sm_training.env
+
+if [ -d {JOB_REMOTE_FUNCTION_WORKSPACE} ]
+then
+    if [ -f "remote_function_conda_env.txt" ]
+    then
+        cp remote_function_conda_env.txt {JOB_REMOTE_FUNCTION_WORKSPACE}/remote_function_conda_env.txt
+    fi
+    printf "INFO: Changing workspace to {JOB_REMOTE_FUNCTION_WORKSPACE}.\\n"
+    cd {JOB_REMOTE_FUNCTION_WORKSPACE}
+fi
+
+if [ -f "remote_function_conda_env.txt" ]
+then
+    conda_env=$(cat remote_function_conda_env.txt)
+
+    if which mamba >/dev/null; then
+        conda_exe="mamba"
+    else
+        conda_exe="conda"
+    fi
+
+    printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
+    printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+        --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
+        -m sagemaker.remote_function.invoke_function \\n"
+
+    $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
+        --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
+        -m sagemaker.remote_function.invoke_function "$@"
+else
+    printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
+    printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+        --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function \\n"
+
+    torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
+        --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.remote_function.invoke_function "$@"
+fi
+"""
+
 SPARK_ENTRYPOINT_SCRIPT = f"""
 #!/bin/bash
@@ -216,6 +373,10 @@ def __init__(
        spark_config: SparkConfig = None,
        use_spot_instances=False,
        max_wait_time_in_seconds=None,
+        disable_output_compression: bool = False,
+        use_torchrun: bool = False,
+        use_mpirun: bool = False,
+        nproc_per_node: Optional[int] = None,
    ):
        """Initialize a _JobSettings instance which configures the remote job.

@@ -397,6 +558,19 @@ def __init__(
            max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                After this amount of time Amazon SageMaker will stop waiting for managed spot
                training job to complete. Defaults to ``None``.
+
+            disable_output_compression (bool): Optional. When set to true, the model is uploaded
+                to Amazon S3 without compression after training finishes.
+
+            use_torchrun (bool): Specifies whether to use torchrun for distributed training.
+                Defaults to ``False``.
+
+            use_mpirun (bool): Specifies whether to use mpirun for distributed training.
+                Defaults to ``False``.
+
+            nproc_per_node (int): Optional. Specifies the number of processes per node for
+                distributed training. Defaults to ``None``.
+                If not set, this is configured automatically based on the instance type.
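The torchrun entry point above sources `/opt/ml/input/sm_training.env` before launching, so the same `SM_*` variables are visible to user code inside the job; a sketch of reading them:

```python
import os

host_count = int(os.environ["SM_HOST_COUNT"])          # e.g. 2
node_rank = int(os.environ["SM_CURRENT_HOST_RANK"])    # 0 on algo-1
nproc_per_node = int(os.environ["SM_NPROC_PER_NODE"])  # GPUs (or CPUs/neurons) per node

world_size = host_count * nproc_per_node
print(f"node {node_rank} of {host_count}, world size {world_size}")
```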
""" self.sagemaker_session = sagemaker_session or Session() self.environment_variables = resolve_value_from_config( @@ -555,6 +729,11 @@ def __init__( tags = format_tags(tags) self.tags = self.sagemaker_session._append_sagemaker_config_tags(tags, REMOTE_FUNCTION_TAGS) + self.disable_output_compression = disable_output_compression + self.use_torchrun = use_torchrun + self.use_mpirun = use_mpirun + self.nproc_per_node = nproc_per_node + @staticmethod def _get_default_image(session): """Return Studio notebook image, if in Studio env. Else, base python. @@ -681,6 +860,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non ) logger.info("Creating job: %s", job_name) + job_settings.sagemaker_session.sagemaker_client.create_training_job(**training_job_request) return _Job( @@ -695,7 +875,7 @@ def compile( job_settings: _JobSettings, job_name: str, s3_base_uri: str, - func: callable, + func: Callable, func_args: tuple, func_kwargs: dict, run_info=None, @@ -779,6 +959,8 @@ def compile( output_config = {"S3OutputPath": s3_base_uri} if job_settings.s3_kms_key is not None: output_config["KmsKeyId"] = job_settings.s3_kms_key + if job_settings.disable_output_compression: + output_config["CompressionType"] = "NONE" request_dict["OutputDataConfig"] = output_config container_args = ["--s3_base_uri", s3_base_uri] @@ -800,6 +982,12 @@ def compile( ).to_string(), ] ) + if job_settings.use_torchrun: + container_args.extend(["--distribution", "torchrun"]) + elif job_settings.use_mpirun: + container_args.extend(["--distribution", "mpirun"]) + if job_settings.nproc_per_node is not None and int(job_settings.nproc_per_node) > 0: + container_args.extend(["--user_nproc_per_node", str(job_settings.nproc_per_node)]) if job_settings.s3_kms_key: container_args.extend(["--s3_kms_key", job_settings.s3_kms_key]) @@ -876,6 +1064,8 @@ def compile( request_dict["Environment"].update({"REMOTE_FUNCTION_SECRET_KEY": hmac_key}) extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri) + extended_request = _extend_mpirun_to_request(extended_request, job_settings) + extended_request = _extend_torchrun_to_request(extended_request, job_settings) return extended_request @@ -951,7 +1141,12 @@ def _get_job_name(job_settings, func): def _prepare_and_upload_runtime_scripts( - spark_config: SparkConfig, s3_base_uri: str, s3_kms_key: str, sagemaker_session: Session + spark_config: SparkConfig, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + use_torchrun: bool = False, + use_mpirun: bool = False, ): """Copy runtime scripts to a folder and upload to S3. @@ -967,6 +1162,12 @@ def _prepare_and_upload_runtime_scripts( s3_kms_key (str): kms key used to encrypt the files uploaded to S3. sagemaker_session (str): SageMaker boto client session. + + use_torchrun (bool): Whether to use torchrun or not. + + use_mpirun (bool): Whether to use mpirun or not. 
     """
    from sagemaker.workflow.utilities import load_step_compilation_context

@@ -988,18 +1189,28 @@
    )
    shutil.copy2(spark_script_path, bootstrap_scripts)

+    if use_torchrun:
+        entry_point_script = ENTRYPOINT_TORCHRUN_SCRIPT
+
+    if use_mpirun:
+        entry_point_script = ENTRYPOINT_MPIRUN_SCRIPT
+
    with open(entrypoint_script_path, "w", newline="\n") as file:
        file.writelines(entry_point_script)

    bootstrap_script_path = os.path.join(
        os.path.dirname(__file__), "runtime_environment", BOOTSTRAP_SCRIPT_NAME
    )
+    mpi_utils_path = os.path.join(
+        os.path.dirname(__file__), "runtime_environment", MPI_UTILS_SCRIPT_NAME
+    )
    runtime_manager_script_path = os.path.join(
        os.path.dirname(__file__), "runtime_environment", RUNTIME_MANAGER_SCRIPT_NAME
    )

    # copy runtime scripts to tmpdir
    shutil.copy2(bootstrap_script_path, bootstrap_scripts)
+    shutil.copy2(mpi_utils_path, bootstrap_scripts)
    shutil.copy2(runtime_manager_script_path, bootstrap_scripts)

    upload_path = S3Uploader.upload(
@@ -1025,6 +1236,8 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
        s3_base_uri=s3_base_uri,
        s3_kms_key=job_settings.s3_kms_key,
        sagemaker_session=job_settings.sagemaker_session,
+        use_torchrun=job_settings.use_torchrun,
+        use_mpirun=job_settings.use_mpirun,
    )

    input_data_config = [
@@ -1365,6 +1578,64 @@ def _upload_serialized_spark_configuration(
    return config_file_s3_uri


+def _extend_mpirun_to_request(
+    request_dict: Dict,
+    job_settings: _JobSettings,
+) -> Dict:
+    """Extend the create training job request with mpirun configuration.
+
+    Args:
+        request_dict (Dict): create training job request dict.
+        job_settings (_JobSettings): the job settings.
+
+    Returns:
+        Dict: the create training job request, with inputs fully replicated
+            across instances when mpirun is used on more than one node.
+    """
+    use_mpirun = job_settings.use_mpirun
+    instance_count = job_settings.instance_count
+
+    if not use_mpirun:
+        return request_dict
+
+    if instance_count == 1:
+        return request_dict
+
+    extended_request = request_dict.copy()
+
+    for input_channel in extended_request["InputDataConfig"]:
+        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+        if s3_data_source:
+            s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+    return extended_request
+
+
+def _extend_torchrun_to_request(
+    request_dict: Dict,
+    job_settings: _JobSettings,
+) -> Dict:
+    """Extend the create training job request with torchrun configuration.
+
+    Args:
+        request_dict (Dict): create training job request dict.
+        job_settings (_JobSettings): the job settings.
+
+    Returns:
+        Dict: the create training job request, with inputs fully replicated
+            across instances when torchrun is used on more than one node.
+    """
+    use_torchrun = job_settings.use_torchrun
+    instance_count = job_settings.instance_count
+
+    if not use_torchrun:
+        return request_dict
+
+    if instance_count == 1:
+        return request_dict
+
+    extended_request = request_dict.copy()
+
+    for input_channel in extended_request["InputDataConfig"]:
+        s3_data_source = input_channel["DataSource"].get("S3DataSource", None)
+        if s3_data_source:
+            s3_data_source["S3DataDistributionType"] = "FullyReplicated"
+
+    return extended_request
+
+
def _extend_spark_config_to_request(
    request_dict: Dict,
    job_settings: _JobSettings,
diff --git a/src/sagemaker/remote_function/runtime_environment/__init__.py b/src/sagemaker/remote_function/runtime_environment/__init__.py
index e69de29bb2..18557a2eb5 100644
--- a/src/sagemaker/remote_function/runtime_environment/__init__.py
+++ b/src/sagemaker/remote_function/runtime_environment/__init__.py
@@ -0,0 +1,14 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container_drivers directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index 8fd83bfcfe..da7c493ae5 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -15,10 +15,14 @@ import argparse import getpass -import sys +import json +import multiprocessing import os -import shutil import pathlib +import shutil +import subprocess +import sys +from typing import Any, Dict if __package__ is None or __package__ == "": from runtime_environment_manager import ( @@ -39,64 +43,48 @@ REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws" BASE_CHANNEL_PATH = "/opt/ml/input/data" FAILURE_REASON_PATH = "/opt/ml/output/failure" -JOB_OUTPUT_DIRS = ["/opt/ml/output", "/opt/ml/model", "/tmp"] +JOB_OUTPUT_DIRS = ["/opt/ml/input", "/opt/ml/output", "/opt/ml/model", "/tmp"] PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh" JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace" SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies" +SM_MODEL_DIR = "/opt/ml/model" -logger = get_logger() +SM_INPUT_DIR = "/opt/ml/input" +SM_INPUT_DATA_DIR = "/opt/ml/input/data" +SM_INPUT_CONFIG_DIR = "/opt/ml/input/config" +SM_OUTPUT_DIR = "/opt/ml/output" +SM_OUTPUT_FAILURE = "/opt/ml/output/failure" +SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" -def main(sys_args=None): - """Entry point for bootstrap script""" - - exit_code = DEFAULT_FAILURE_CODE +SM_MASTER_ADDR = "algo-1" +SM_MASTER_PORT = 7777 - try: - args = _parse_args(sys_args) - client_python_version = args.client_python_version - client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version - job_conda_env = args.job_conda_env - pipeline_execution_id = args.pipeline_execution_id - dependency_settings = _DependencySettings.from_string(args.dependency_settings) - func_step_workspace = args.func_step_s3_dir +RESOURCE_CONFIG = f"{SM_INPUT_CONFIG_DIR}/resourceconfig.json" +ENV_OUTPUT_FILE = "/opt/ml/input/sm_training.env" - conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV") +SENSITIVE_KEYWORDS = ["SECRET", "PASSWORD", "KEY", "TOKEN", "PRIVATE", "CREDS", "CREDENTIALS"] +HIDDEN_VALUE = "******" - RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env) +SM_EFA_NCCL_INSTANCES = [ + "ml.g4dn.8xlarge", + "ml.g4dn.12xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + "ml.p4d.24xlarge", + "ml.p4de.24xlarge", + "ml.p5.48xlarge", + "ml.trn1.32xlarge", +] - user = getpass.getuser() - if user != "root": - log_message = ( - "The job is running on non-root user: %s. Adding write permissions to the " - "following job output directories: %s." 
- ) - logger.info(log_message, user, JOB_OUTPUT_DIRS) - RuntimeEnvironmentManager().change_dir_permission( - dirs=JOB_OUTPUT_DIRS, new_permission="777" - ) +SM_EFA_RDMA_INSTANCES = [ + "ml.p4d.24xlarge", + "ml.p4de.24xlarge", + "ml.trn1.32xlarge", +] - if pipeline_execution_id: - _bootstrap_runtime_env_for_pipeline_step( - client_python_version, func_step_workspace, conda_env, dependency_settings - ) - else: - _bootstrap_runtime_env_for_remote_function( - client_python_version, conda_env, dependency_settings - ) - - RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version( - client_sagemaker_pysdk_version - ) - - exit_code = SUCCESS_EXIT_CODE - except Exception as e: # pylint: disable=broad-except - logger.exception("Error encountered while bootstrapping runtime environment: %s", e) - - _write_failure_reason_file(str(e)) - finally: - sys.exit(exit_code) +logger = get_logger() def _bootstrap_runtime_env_for_remote_function( @@ -283,9 +271,332 @@ def _parse_args(sys_args): parser.add_argument("--pipeline_execution_id", type=str) parser.add_argument("--dependency_settings", type=str) parser.add_argument("--func_step_s3_dir", type=str) + parser.add_argument("--distribution", type=str, default=None) + parser.add_argument("--user_nproc_per_node", type=str, default=None) args, _ = parser.parse_known_args(sys_args) return args +def log_key_value(key: str, value: str): + """Log a key-value pair, masking sensitive values if necessary.""" + if any(keyword.lower() in key.lower() for keyword in SENSITIVE_KEYWORDS): + logger.info("%s=%s", key, HIDDEN_VALUE) + elif isinstance(value, dict): + masked_value = mask_sensitive_info(value) + logger.info("%s=%s", key, json.dumps(masked_value)) + else: + try: + decoded_value = json.loads(value) + if isinstance(decoded_value, dict): + masked_value = mask_sensitive_info(decoded_value) + logger.info("%s=%s", key, json.dumps(masked_value)) + else: + logger.info("%s=%s", key, decoded_value) + except (json.JSONDecodeError, TypeError): + logger.info("%s=%s", key, value) + + +def log_env_variables(env_vars_dict: Dict[str, Any]): + """Log Environment Variables from the environment and an env_vars_dict.""" + for key, value in os.environ.items(): + log_key_value(key, value) + + for key, value in env_vars_dict.items(): + log_key_value(key, value) + + +def mask_sensitive_info(data): + """Recursively mask sensitive information in a dictionary.""" + if isinstance(data, dict): + for k, v in data.items(): + if isinstance(v, dict): + data[k] = mask_sensitive_info(v) + elif isinstance(v, str) and any( + keyword.lower() in k.lower() for keyword in SENSITIVE_KEYWORDS + ): + data[k] = HIDDEN_VALUE + return data + + +def num_cpus() -> int: + """Return the number of CPUs available in the current container. + + Returns: + int: Number of CPUs available in the current container. + """ + return multiprocessing.cpu_count() + + +def num_gpus() -> int: + """Return the number of GPUs available in the current container. + + Returns: + int: Number of GPUs available in the current container. + """ + try: + cmd = ["nvidia-smi", "--list-gpus"] + output = subprocess.check_output(cmd).decode("utf-8") + return sum(1 for line in output.splitlines() if line.startswith("GPU ")) + except (OSError, subprocess.CalledProcessError): + logger.info("No GPUs detected (normal if no gpus installed)") + return 0 + + +def num_neurons() -> int: + """Return the number of neuron cores available in the current container. + + Returns: + int: Number of Neuron Cores available in the current container. 
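A quick illustration of the keyword-based masking defined above (`mask_sensitive_info` is the helper from this file):

```python
creds = {"aws_secret_access_key": "abc123", "region": "us-west-2"}
masked = mask_sensitive_info(creds)

# "aws_secret_access_key" matches both SECRET and KEY in SENSITIVE_KEYWORDS,
# so its value is hidden; non-sensitive keys pass through unchanged.
assert masked == {"aws_secret_access_key": "******", "region": "us-west-2"}
```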
+ """ + try: + cmd = ["neuron-ls", "-j"] + output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8") + j = json.loads(output) + neuron_cores = 0 + for item in j: + neuron_cores += item.get("nc_count", 0) + logger.info("Found %s neurons on this instance", neuron_cores) + return neuron_cores + except OSError: + logger.info("No Neurons detected (normal if no neurons installed)") + return 0 + except subprocess.CalledProcessError as e: + if e.output is not None: + try: + msg = e.output.decode("utf-8").partition("error=")[2] + logger.info( + "No Neurons detected (normal if no neurons installed). \ + If neuron installed then %s", + msg, + ) + except AttributeError: + logger.info("No Neurons detected (normal if no neurons installed)") + else: + logger.info("No Neurons detected (normal if no neurons installed)") + + return 0 + + +def safe_serialize(data): + """Serialize the data without wrapping strings in quotes. + + This function handles the following cases: + 1. If `data` is a string, it returns the string as-is without wrapping in quotes. + 2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns + the JSON-encoded string using `json.dumps()`. + 3. If `data` cannot be serialized (e.g., a custom object), it returns the string + representation of the data using `str(data)`. + + Args: + data (Any): The data to serialize. + + Returns: + str: The serialized JSON-compatible string or the string representation of the input. + """ + if isinstance(data, str): + return data + try: + return json.dumps(data) + except TypeError: + return str(data) + + +def set_env( + resource_config: Dict[str, Any], + distribution: str = None, + user_nproc_per_node: bool = None, + output_file: str = ENV_OUTPUT_FILE, +): + """Set environment variables for the training job container. + + Args: + resource_config (Dict[str, Any]): Resource configuration for the training job. + output_file (str): Output file to write the environment variables. + """ + # Constants + env_vars = { + "SM_MODEL_DIR": SM_MODEL_DIR, + "SM_INPUT_DIR": SM_INPUT_DIR, + "SM_INPUT_DATA_DIR": SM_INPUT_DATA_DIR, + "SM_INPUT_CONFIG_DIR": SM_INPUT_CONFIG_DIR, + "SM_OUTPUT_DIR": SM_OUTPUT_DIR, + "SM_OUTPUT_FAILURE": SM_OUTPUT_FAILURE, + "SM_OUTPUT_DATA_DIR": SM_OUTPUT_DATA_DIR, + "SM_MASTER_ADDR": SM_MASTER_ADDR, + "SM_MASTER_PORT": SM_MASTER_PORT, + } + + # Host Variables + current_host = resource_config["current_host"] + current_instance_type = resource_config["current_instance_type"] + hosts = resource_config["hosts"] + sorted_hosts = sorted(hosts) + + env_vars["SM_CURRENT_HOST"] = current_host + env_vars["SM_CURRENT_INSTANCE_TYPE"] = current_instance_type + env_vars["SM_HOSTS"] = sorted_hosts + env_vars["SM_NETWORK_INTERFACE_NAME"] = resource_config["network_interface_name"] + env_vars["SM_HOST_COUNT"] = len(sorted_hosts) + env_vars["SM_CURRENT_HOST_RANK"] = sorted_hosts.index(current_host) + + env_vars["SM_NUM_CPUS"] = num_cpus() + env_vars["SM_NUM_GPUS"] = num_gpus() + env_vars["SM_NUM_NEURONS"] = num_neurons() + + # Misc. 
+ env_vars["SM_RESOURCE_CONFIG"] = resource_config + + if user_nproc_per_node is not None and int(user_nproc_per_node) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(user_nproc_per_node) + else: + if int(env_vars["SM_NUM_GPUS"]) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_GPUS"]) + elif int(env_vars["SM_NUM_NEURONS"]) > 0: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_NEURONS"]) + else: + env_vars["SM_NPROC_PER_NODE"] = int(env_vars["SM_NUM_CPUS"]) + + # All Training Environment Variables + env_vars["SM_TRAINING_ENV"] = { + "current_host": env_vars["SM_CURRENT_HOST"], + "current_instance_type": env_vars["SM_CURRENT_INSTANCE_TYPE"], + "hosts": env_vars["SM_HOSTS"], + "host_count": env_vars["SM_HOST_COUNT"], + "nproc_per_node": env_vars["SM_NPROC_PER_NODE"], + "master_addr": env_vars["SM_MASTER_ADDR"], + "master_port": env_vars["SM_MASTER_PORT"], + "input_config_dir": env_vars["SM_INPUT_CONFIG_DIR"], + "input_data_dir": env_vars["SM_INPUT_DATA_DIR"], + "input_dir": env_vars["SM_INPUT_DIR"], + "job_name": os.environ["TRAINING_JOB_NAME"], + "model_dir": env_vars["SM_MODEL_DIR"], + "network_interface_name": env_vars["SM_NETWORK_INTERFACE_NAME"], + "num_cpus": env_vars["SM_NUM_CPUS"], + "num_gpus": env_vars["SM_NUM_GPUS"], + "num_neurons": env_vars["SM_NUM_NEURONS"], + "output_data_dir": env_vars["SM_OUTPUT_DATA_DIR"], + "resource_config": env_vars["SM_RESOURCE_CONFIG"], + } + + if distribution and distribution == "torchrun": + logger.info("Distribution: torchrun") + + instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] + network_interface_name = env_vars.get("SM_NETWORK_INTERFACE_NAME", "eth0") + + if instance_type in SM_EFA_NCCL_INSTANCES: + # Enable EFA use + env_vars["FI_PROVIDER"] = "efa" + if instance_type in SM_EFA_RDMA_INSTANCES: + # Use EFA's RDMA functionality for one-sided and two-sided transfer + env_vars["FI_EFA_USE_DEVICE_RDMA"] = "1" + env_vars["RDMAV_FORK_SAFE"] = "1" + env_vars["NCCL_SOCKET_IFNAME"] = str(network_interface_name) + env_vars["NCCL_PROTO"] = "simple" + elif distribution and distribution == "mpirun": + logger.info("Distribution: mpirun") + + env_vars["MASTER_ADDR"] = env_vars["SM_MASTER_ADDR"] + env_vars["MASTER_PORT"] = str(env_vars["SM_MASTER_PORT"]) + + host_list = [ + "{}:{}".format(host, int(env_vars["SM_NPROC_PER_NODE"])) for host in sorted_hosts + ] + env_vars["SM_HOSTS_LIST"] = ",".join(host_list) + + instance_type = env_vars["SM_CURRENT_INSTANCE_TYPE"] + + if instance_type in SM_EFA_NCCL_INSTANCES: + env_vars["SM_FI_PROVIDER"] = "-x FI_PROVIDER=efa" + env_vars["SM_NCCL_PROTO"] = "-x NCCL_PROTO=simple" + else: + env_vars["SM_FI_PROVIDER"] = "" + env_vars["SM_NCCL_PROTO"] = "" + + if instance_type in SM_EFA_RDMA_INSTANCES: + env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "-x FI_EFA_USE_DEVICE_RDMA=1" + else: + env_vars["SM_FI_EFA_USE_DEVICE_RDMA"] = "" + + with open(output_file, "w") as f: + for key, value in env_vars.items(): + f.write(f"export {key}='{safe_serialize(value)}'\n") + + logger.info("Environment Variables:") + log_env_variables(env_vars_dict=env_vars) + + +def main(sys_args=None): + """Entry point for bootstrap script""" + + exit_code = DEFAULT_FAILURE_CODE + + try: + args = _parse_args(sys_args) + + logger.info("Arguments:") + for arg in vars(args): + logger.info("%s=%s", arg, getattr(args, arg)) + + client_python_version = args.client_python_version + client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version + job_conda_env = args.job_conda_env + pipeline_execution_id = args.pipeline_execution_id + dependency_settings 
= _DependencySettings.from_string(args.dependency_settings)
+        func_step_workspace = args.func_step_s3_dir
+        distribution = args.distribution
+        user_nproc_per_node = args.user_nproc_per_node
+
+        conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")
+
+        RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
+
+        user = getpass.getuser()
+        if user != "root":
+            log_message = (
+                "The job is running on non-root user: %s. Adding write permissions to the "
+                "following job output directories: %s."
+            )
+            logger.info(log_message, user, JOB_OUTPUT_DIRS)
+            RuntimeEnvironmentManager().change_dir_permission(
+                dirs=JOB_OUTPUT_DIRS, new_permission="777"
+            )
+
+        if pipeline_execution_id:
+            _bootstrap_runtime_env_for_pipeline_step(
+                client_python_version, func_step_workspace, conda_env, dependency_settings
+            )
+        else:
+            _bootstrap_runtime_env_for_remote_function(
+                client_python_version, conda_env, dependency_settings
+            )
+
+        RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
+            client_sagemaker_pysdk_version
+        )
+
+        if os.path.exists(RESOURCE_CONFIG):
+            try:
+                logger.info("Found %s", RESOURCE_CONFIG)
+                with open(RESOURCE_CONFIG, "r") as f:
+                    resource_config = json.load(f)
+                set_env(
+                    resource_config=resource_config,
+                    distribution=distribution,
+                    user_nproc_per_node=user_nproc_per_node,
+                )
+            except (json.JSONDecodeError, FileNotFoundError) as e:
+                # A malformed or missing resource config is logged but does not fail the job
+                logger.error("Error processing %s: %s", RESOURCE_CONFIG, str(e))
+
+        exit_code = SUCCESS_EXIT_CODE
+    except Exception as e:  # pylint: disable=broad-except
+        logger.exception("Error encountered while bootstrapping runtime environment: %s", e)
+
+        _write_failure_reason_file(str(e))
+    finally:
+        sys.exit(exit_code)
+
+
 if __name__ == "__main__":
     main(sys.argv[1:])
diff --git a/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py b/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py
new file mode 100644
index 0000000000..6f3897fb0b
--- /dev/null
+++ b/src/sagemaker/remote_function/runtime_environment/mpi_utils_remote.py
@@ -0,0 +1,252 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Utility functions for the runtime environment.
This must be kept independent of SageMaker PySDK."""
+from __future__ import absolute_import
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+import time
+from typing import List
+
+import paramiko
+
+if __package__ is None or __package__ == "":
+    from runtime_environment_manager import (
+        get_logger,
+    )
+else:
+    from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
+        get_logger,
+    )
+
+SUCCESS_EXIT_CODE = 0
+DEFAULT_FAILURE_CODE = 1
+
+FINISHED_STATUS_FILE = "/tmp/done.algo-1"
+READY_FILE = "/tmp/ready.%s"
+DEFAULT_SSH_PORT = 22
+
+FAILURE_REASON_PATH = "/opt/ml/output/failure"
+
+logger = get_logger()
+
+
+class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
+    """Class to handle host key policy for SageMaker distributed training SSH connections.
+
+    Example:
+    >>> client = paramiko.SSHClient()
+    >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
+    >>> # Will succeed for SageMaker algorithm containers
+    >>> client.connect('algo-1234.internal')
+    >>> # Will raise SSHException for other unknown hosts
+    >>> client.connect('unknown-host')  # raises SSHException
+    """
+
+    def missing_host_key(self, client, hostname, key):
+        """Accept host keys for algo-* hostnames, reject others.
+
+        Args:
+            client: The SSHClient instance
+            hostname: The hostname attempting to connect
+            key: The host key
+        Raises:
+            paramiko.SSHException: If hostname doesn't match algo-* pattern
+        """
+        if hostname.startswith("algo-"):
+            client.get_host_keys().add(hostname, key.get_name(), key)
+            return
+        raise paramiko.SSHException(f"Unknown host key for {hostname}")
+
+
+def _parse_args(sys_args):
+    """Parses CLI arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--job_ended", type=str, default="0")
+    args, _ = parser.parse_known_args(sys_args)
+    return args
+
+
+def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
+    """Check if the connection to the provided host and port is possible."""
+    try:
+        with paramiko.SSHClient() as client:
+            client.load_system_host_keys()
+            client.set_missing_host_key_policy(CustomHostKeyPolicy())
+            client.connect(host, port=port)
+            logger.info("Can connect to host %s", host)
+            return True
+    except Exception as e:  # pylint: disable=W0703
+        logger.info("Cannot connect to host %s", host)
+        logger.debug("Connection failed with exception: %s", e)
+        return False
+
+
+def _write_file_to_host(host: str, status_file: str) -> bool:
+    """Write a file to the provided host."""
+    try:
+        logger.info("Writing %s to %s", status_file, host)
+        subprocess.run(
+            ["ssh", host, "touch", f"{status_file}"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+        logger.info("Finished writing status file")
+        return True
+    except subprocess.CalledProcessError:
+        logger.info("Cannot connect to %s", host)
+        return False
+
+
+def _write_failure_reason_file(failure_msg):
+    """Create a file 'failure' with failure reason written if bootstrap runtime env failed.
+
+    See: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
+    Args:
+        failure_msg: The content of the file to be written.
+ """ + if not os.path.exists(FAILURE_REASON_PATH): + with open(FAILURE_REASON_PATH, "w") as f: + f.write("RuntimeEnvironmentError: " + failure_msg) + + +def _wait_for_master(master_host: str, port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Worker nodes wait until they can connect to the master node.""" + start_time = time.time() + while True: + logger.info("Worker is attempting to connect to the master node %s...", master_host) + if _can_connect(master_host, port): + logger.info("Worker can connect to master node %s.", master_host) + break + if time.time() - start_time > timeout: + raise TimeoutError("Timed out waiting for master %s to be reachable." % master_host) + + time.sleep(5) # Wait for 5 seconds before trying again + + +def _wait_for_status_file(status_file: str): + """Wait for the status file to be created.""" + logger.info("Waiting for status file %s", status_file) + while not os.path.exists(status_file): + time.sleep(30) + logger.info("Found status file %s", status_file) + + +def _wait_for_workers(worker_hosts: List[str], port: int = DEFAULT_SSH_PORT, timeout: int = 300): + """Master node waits until it can connect to all worker nodes.""" + start_time = time.time() + if not worker_hosts: + logger.info("No worker nodes to connect to.") + return + + while True: + logger.info("Master is attempting to connect to all workers...") + all_workers_connected = all( + _can_connect(worker, port) and os.path.exists(READY_FILE % worker) + for worker in worker_hosts + ) + + if all_workers_connected: + logger.info("Master can connect to all worker nodes.") + break + if time.time() - start_time > timeout: + raise TimeoutError("Timed out waiting for workers to be reachable.") + + time.sleep(5) # Wait for 5 seconds before trying again + + +def bootstrap_master_node(worker_hosts: List[str]): + """Bootstrap the master node.""" + logger.info("Bootstrapping master node...") + _wait_for_workers(worker_hosts) + + +def bootstrap_worker_node( + master_host: str, current_host: str, status_file: str = FINISHED_STATUS_FILE +): + """Bootstrap the worker nodes.""" + logger.info("Bootstrapping worker node...") + _wait_for_master(master_host) + _write_file_to_host(master_host, READY_FILE % current_host) + _wait_for_status_file(status_file) + + +def start_sshd_daemon(): + """Start the SSH daemon on the current node.""" + sshd_executable = "/usr/sbin/sshd" + + if not os.path.exists(sshd_executable): + raise RuntimeError("SSH daemon not found.") + + # Start the sshd in daemon mode (-D) + subprocess.Popen([sshd_executable, "-D"]) + logger.info("Started SSH daemon.") + + +def write_status_file_to_workers(worker_hosts: List[str], status_file: str = FINISHED_STATUS_FILE): + """Write the status file to all worker nodes.""" + for worker in worker_hosts: + retry = 0 + while not _write_file_to_host(worker, status_file): + time.sleep(5) + retry += 1 + if retry > 5: + raise TimeoutError("Timed out waiting for %s to be reachable." 
% worker) + logger.info("Retrying to write status file to %s", worker) + + +def main(sys_args=None): + """Entry point for bootstrap script""" + try: + args = _parse_args(sys_args) + + job_ended = args.job_ended + + main_host = os.environ["SM_MASTER_ADDR"] + current_host = os.environ["SM_CURRENT_HOST"] + + if job_ended == "0": + logger.info("Job is running, bootstrapping nodes") + + start_sshd_daemon() + + if current_host != main_host: + bootstrap_worker_node(main_host, current_host) + else: + sorted_hosts = json.loads(os.environ["SM_HOSTS"]) + worker_hosts = [host for host in sorted_hosts if host != main_host] + + bootstrap_master_node(worker_hosts) + else: + logger.info("Job ended, writing status file to workers") + + if current_host == main_host: + sorted_hosts = json.loads(os.environ["SM_HOSTS"]) + worker_hosts = [host for host in sorted_hosts if host != main_host] + + write_status_file_to_workers(worker_hosts) + except Exception as e: # pylint: disable=broad-except + logger.exception("Error encountered while bootstrapping runtime environment: %s", e) + + _write_failure_reason_file(str(e)) + + sys.exit(DEFAULT_FAILURE_CODE) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index df14ac558f..d0ddea4432 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -31,12 +31,14 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, instance_type: Optional[str] = None, + config_name: Optional[str] = None, ) -> ResourceRequirements: """Retrieves the default resource requirements for the model matching the given arguments. @@ -47,6 +49,8 @@ def retrieve_default( retrieve the default resource requirements. (Default: None). model_version (str): The version of the model for which to retrieve the default resource requirements. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -62,6 +66,7 @@ def retrieve_default( chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). instance_type (str): An instance type to optionally supply in order to get host requirements specific for the instance type. + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: str: The default resource requirements to use for the model. 
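
For context, here is a hedged usage sketch of the updated `retrieve_default` API in this hunk. The model ID and instance type are placeholders for illustration, not values taken from this change:

```python
from sagemaker.resource_requirements import retrieve_default

# Hypothetical JumpStart model ID and instance type, for illustration only.
requirements = retrieve_default(
    model_id="meta-textgeneration-llama-2-7b",
    model_version="*",
    scope="inference",
    instance_type="ml.g5.2xlarge",
    config_name=None,  # optionally pin a specific JumpStart model config
)
print(requirements)
```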
@@ -78,13 +83,15 @@ def retrieve_default(
         raise ValueError("Must specify scope for resource requirements.")
 
     return artifacts._retrieve_default_resources(
-        model_id,
-        model_version,
-        scope,
-        region,
-        tolerate_vulnerable_model,
-        tolerate_deprecated_model,
+        model_id=model_id,
+        model_version=model_version,
+        hub_arn=hub_arn,
+        scope=scope,
+        region=region,
+        tolerate_vulnerable_model=tolerate_vulnerable_model,
+        tolerate_deprecated_model=tolerate_deprecated_model,
         model_type=model_type,
         sagemaker_session=sagemaker_session,
         instance_type=instance_type,
+        config_name=config_name,
     )
diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py
index 3ed539fa2e..f1e1407633 100644
--- a/src/sagemaker/rl/estimator.py
+++ b/src/sagemaker/rl/estimator.py
@@ -84,7 +84,7 @@ def __init__(
         hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
         image_uri: Optional[Union[str, PipelineVariable]] = None,
         metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
-        **kwargs
+        **kwargs,
     ):
         """Creates an RLEstimator for managed Reinforcement Learning (RL).
@@ -120,8 +120,8 @@ def __init__(
             source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to
                 a directory with any other training source code dependencies aside from the entry
                 point file (default: None). If ``source_dir`` is an S3 URI, it must
-                point to a tar.gz file. Structure within this directory are preserved
-                when training on Amazon SageMaker.
+                point to a file with name ``sourcedir.tar.gz``. Structure within this directory
+                is preserved when training on Amazon SageMaker.
             hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters
                 that will be used for training (default: None). The hyperparameters are made
                 accessible as a dict[str, str] to the training code on
@@ -168,7 +168,7 @@ def __init__(
             hyperparameters,
             image_uri=image_uri,
             metric_definitions=metric_definitions,
-            **kwargs
+            **kwargs,
         )
 
     def create_model(
@@ -178,7 +178,7 @@ def create_model(
         entry_point=None,
         source_dir=None,
         dependencies=None,
-        **kwargs
+        **kwargs,
     ):
         """Create a SageMaker ``RLEstimatorModel`` object that can be deployed to an Endpoint.
diff --git a/src/sagemaker/s3_utils.py b/src/sagemaker/s3_utils.py
index e53cdbe02a..f59c8a299f 100644
--- a/src/sagemaker/s3_utils.py
+++ b/src/sagemaker/s3_utils.py
@@ -45,6 +45,19 @@ def parse_s3_url(url):
     return parsed_url.netloc, parsed_url.path.lstrip("/")
 
 
+def is_s3_url(url):
+    """Returns True if ``url`` is an S3 URL, False if not.
+
+    Args:
+        url (str): The URL to check.
+
+    Returns:
+        bool: True if ``url`` uses the ``s3://`` scheme, False otherwise.
+    """
+    parsed_url = urlparse(url)
+    return parsed_url.scheme == "s3"
+
+
 def s3_path_join(*args, with_end_slash: bool = False):
     """Returns the arguments joined by a slash ("/"), similar to ``os.path.join()`` (on Unix).
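
The new `is_s3_url` helper composes with the existing `parse_s3_url` and `s3_path_join` utilities. A quick sketch of the expected behavior, assuming the implementations shown in this hunk:

```python
from sagemaker.s3_utils import is_s3_url, parse_s3_url, s3_path_join

assert is_s3_url("s3://my-bucket/prefix/model.tar.gz")
assert not is_s3_url("https://example.com/model.tar.gz")

# parse_s3_url splits an S3 URL into (bucket, key).
bucket, key = parse_s3_url("s3://my-bucket/prefix/model.tar.gz")

# s3_path_join joins path segments with "/", similar to os.path.join on Unix.
uri = s3_path_join("s3://my-bucket", "prefix", "model.tar.gz")
```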
diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 9a1c4933d2..f280a627d2 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -20,6 +20,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -29,10 +30,13 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -42,6 +46,8 @@ def retrieve( retrieve the script S3 URI. model_version (str): The version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). script_scope (str): The script type. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model @@ -55,6 +61,9 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). + model_type (JumpStartModelType): The type of the model, can be open weights model + or proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: str: The model script URI for the corresponding model. @@ -71,11 +80,14 @@ def retrieve( ) return artifacts._retrieve_script_uri( - model_id, - model_version, - script_scope, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + script_scope=script_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, + model_type=model_type, ) diff --git a/src/sagemaker/serializer_utils.py b/src/sagemaker/serializer_utils.py new file mode 100644 index 0000000000..96a931084c --- /dev/null +++ b/src/sagemaker/serializer_utils.py @@ -0,0 +1,222 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
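
Before the module body, a hedged round-trip sketch of the RecordIO helpers that the new `serializer_utils` module below provides; the array values are arbitrary illustrations:

```python
import io

import numpy as np

from sagemaker.serializer_utils import read_records, write_numpy_to_dense_tensor

# Round-trip a small dense float32 matrix through the RecordIO protobuf format.
features = np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
labels = np.array([0.0, 1.0], dtype="float32")

buf = io.BytesIO()
write_numpy_to_dense_tensor(buf, features, labels)
buf.seek(0)
records = read_records(buf)  # two Record protobuf messages
```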
+"""Placeholder docstring""" +from __future__ import absolute_import + +import logging +import struct +import sys + +import numpy as np + +from sagemaker.amazon.record_pb2 import Record +from sagemaker.utils import DeferredError + + +def _write_feature_tensor(resolved_type, record, vector): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.features["values"].int32_tensor.values.extend(vector) + elif resolved_type == "Float64": + record.features["values"].float64_tensor.values.extend(vector) + elif resolved_type == "Float32": + record.features["values"].float32_tensor.values.extend(vector) + + +def _write_label_tensor(resolved_type, record, scalar): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.label["values"].int32_tensor.values.extend([scalar]) + elif resolved_type == "Float64": + record.label["values"].float64_tensor.values.extend([scalar]) + elif resolved_type == "Float32": + record.label["values"].float32_tensor.values.extend([scalar]) + + +def _write_keys_tensor(resolved_type, record, vector): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.features["values"].int32_tensor.keys.extend(vector) + elif resolved_type == "Float64": + record.features["values"].float64_tensor.keys.extend(vector) + elif resolved_type == "Float32": + record.features["values"].float32_tensor.keys.extend(vector) + + +def _write_shape(resolved_type, record, scalar): + """Placeholder Docstring""" + if resolved_type == "Int32": + record.features["values"].int32_tensor.shape.extend([scalar]) + elif resolved_type == "Float64": + record.features["values"].float64_tensor.shape.extend([scalar]) + elif resolved_type == "Float32": + record.features["values"].float32_tensor.shape.extend([scalar]) + + +def write_numpy_to_dense_tensor(file, array, labels=None): + """Writes a numpy array to a dense tensor + + Args: + file: + array: + labels: + """ + + # Validate shape of array and labels, resolve array and label types + if not len(array.shape) == 2: + raise ValueError("Array must be a Matrix") + if labels is not None: + if not len(labels.shape) == 1: + raise ValueError("Labels must be a Vector") + if labels.shape[0] not in array.shape: + raise ValueError( + "Label shape {} not compatible with array shape {}".format( + labels.shape, array.shape + ) + ) + resolved_label_type = _resolve_type(labels.dtype) + resolved_type = _resolve_type(array.dtype) + + # Write each vector in array into a Record in the file object + record = Record() + for index, vector in enumerate(array): + record.Clear() + _write_feature_tensor(resolved_type, record, vector) + if labels is not None: + _write_label_tensor(resolved_label_type, record, labels[index]) + _write_recordio(file, record.SerializeToString()) + + +def write_spmatrix_to_sparse_tensor(file, array, labels=None): + """Writes a scipy sparse matrix to a sparse tensor + + Args: + file: + array: + labels: + """ + try: + import scipy + except ImportError as e: + logging.warning( + "scipy failed to import. Sparse matrix functions will be impaired or broken." 
+        )
+        # Any subsequent attempt to use scipy will raise the ImportError
+        scipy = DeferredError(e)
+
+    if not scipy.sparse.issparse(array):
+        raise TypeError("Array must be sparse")
+
+    # Validate shape of array and labels, resolve array and label types
+    if not len(array.shape) == 2:
+        raise ValueError("Array must be a Matrix")
+    if labels is not None:
+        if not len(labels.shape) == 1:
+            raise ValueError("Labels must be a Vector")
+        if labels.shape[0] not in array.shape:
+            raise ValueError(
+                "Label shape {} not compatible with array shape {}".format(
+                    labels.shape, array.shape
+                )
+            )
+        resolved_label_type = _resolve_type(labels.dtype)
+    resolved_type = _resolve_type(array.dtype)
+
+    csr_array = array.tocsr()
+    n_rows, n_cols = csr_array.shape
+
+    record = Record()
+    for row_idx in range(n_rows):
+        record.Clear()
+        row = csr_array.getrow(row_idx)
+        # Write values
+        _write_feature_tensor(resolved_type, record, row.data)
+        # Write keys
+        _write_keys_tensor(resolved_type, record, row.indices.astype(np.uint64))
+
+        # Write labels
+        if labels is not None:
+            _write_label_tensor(resolved_label_type, record, labels[row_idx])
+
+        # Write shape
+        _write_shape(resolved_type, record, n_cols)
+
+        _write_recordio(file, record.SerializeToString())
+
+
+def read_records(file):
+    """Eagerly read a collection of Amazon Record protobuf objects from file.
+
+    Args:
+        file: A file-like object, opened in binary mode, to read records from.
+    """
+    records = []
+    for record_data in read_recordio(file):
+        record = Record()
+        record.ParseFromString(record_data)
+        records.append(record)
+    return records
+
+
+# MXNet requires recordio records have length in bytes that's a multiple of 4
+# This sets up padding bytes to append to the end of the record, for different
+# amounts of padding required.
+padding = {}
+for amount in range(4):
+    if sys.version_info >= (3,):
+        padding[amount] = bytes([0x00 for _ in range(amount)])
+    else:
+        padding[amount] = bytearray([0x00 for _ in range(amount)])
+
+_kmagic = 0xCED7230A
+
+
+def _write_recordio(f, data):
+    """Writes a single data point as a RecordIO record to the given file.
+ + Args: + f: + data: + """ + length = len(data) + f.write(struct.pack("I", _kmagic)) + f.write(struct.pack("I", length)) + pad = (((length + 3) >> 2) << 2) - length + f.write(data) + f.write(padding[pad]) + + +def read_recordio(f): + """Placeholder Docstring""" + while True: + try: + (read_kmagic,) = struct.unpack("I", f.read(4)) + except struct.error: + return + assert read_kmagic == _kmagic + (len_record,) = struct.unpack("I", f.read(4)) + pad = (((len_record + 3) >> 2) << 2) - len_record + yield f.read(len_record) + if pad: + f.read(pad) + + +def _resolve_type(dtype): + """Placeholder Docstring""" + if dtype == np.dtype(int): + return "Int32" + if dtype == np.dtype(float): + return "Float64" + if dtype == np.dtype("float32"): + return "Float32" + raise ValueError("Unsupported dtype {} on array".format(dtype)) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index aefb52bd97..be46be0856 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -30,8 +30,10 @@ SparseMatrixSerializer, TorchTensorSerializer, StringSerializer, + RecordSerializer, ) +from sagemaker.deprecations import deprecated_class from sagemaker.jumpstart import artifacts, utils as jumpstart_utils from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartModelType @@ -42,9 +44,11 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model matching the given arguments. @@ -55,6 +59,8 @@ def retrieve_options( retrieve the supported serializers. (Default: None). model_version (str): The version of the model for which to retrieve the supported serializers. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -66,6 +72,7 @@ def retrieve_options( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: List[SimpleBaseSerializer]: The supported serializers to use for the model. 
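
A hedged sketch of the serializer retrieval API extended in these hunks; the model ID is a placeholder, not one referenced by this change:

```python
from sagemaker import serializers

# Hypothetical JumpStart model ID, for illustration only.
supported = serializers.retrieve_options(
    model_id="huggingface-text2text-flan-t5-base",
    model_version="*",
)
default = serializers.retrieve_default(
    model_id="huggingface-text2text-flan-t5-base",
    model_version="*",
)
```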
@@ -79,12 +86,14 @@ def retrieve_options( ) return artifacts._retrieve_serializer_options( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + config_name=config_name, ) @@ -92,10 +101,12 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + config_name: Optional[str] = None, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -106,6 +117,8 @@ def retrieve_default( retrieve the default serializer. (Default: None). model_version (str): The version of the model for which to retrieve the default serializer. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -117,6 +130,7 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None). Returns: SimpleBaseSerializer: The default serializer to use for the model. 
@@ -130,11 +144,16 @@ def retrieve_default( ) return artifacts._retrieve_default_serializer( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, + config_name=config_name, ) + + +numpy_to_record_serializer = deprecated_class(RecordSerializer, "numpy_to_record_serializer") diff --git a/src/sagemaker/serve/builder/djl_builder.py b/src/sagemaker/serve/builder/djl_builder.py index e89c1b8e9c..9b1ebf1257 100644 --- a/src/sagemaker/serve/builder/djl_builder.py +++ b/src/sagemaker/serve/builder/djl_builder.py @@ -15,7 +15,6 @@ import logging from typing import Type from abc import ABC, abstractmethod -from pathlib import Path from datetime import datetime, timedelta from sagemaker.model import Model @@ -25,18 +24,19 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) +from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.tuning import ( _serial_benchmark, _concurrent_benchmark, _more_performant, _pretty_print_results, ) +from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf from sagemaker.serve.model_server.djl_serving.utils import ( - _auto_detect_engine, - _set_serve_properties, _get_admissible_tensor_parallel_degrees, _get_admissible_dtypes, _get_default_tensor_parallel_degree, + _get_default_djl_configurations, ) from sagemaker.serve.utils.local_hardware import ( _get_nb_instance, @@ -45,24 +45,19 @@ _get_gpu_info_fallback, ) from sagemaker.serve.model_server.djl_serving.prepare import ( - prepare_for_djl_serving, _create_dir_structure, ) -from sagemaker.serve.utils.predictors import DjlLocalModePredictor -from sagemaker.serve.utils.types import ModelServer, _DjlEngine +from sagemaker.serve.utils.predictors import InProcessModePredictor, DjlLocalModePredictor +from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.telemetry_logger import _capture_telemetry -from sagemaker.djl_inference.model import ( - DeepSpeedModel, - FasterTransformerModel, - HuggingFaceAccelerateModel, -) +from sagemaker.djl_inference.model import DJLModel from sagemaker.base_predictor import PredictorBase logger = logging.getLogger(__name__) +LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS] # Match JumpStart DJL entrypoint format -_DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py" _CODE_FOLDER = "code" _INVALID_SAMPLE_DATA_EX = ( 'For djl-serving, sample input must be of {"inputs": str, "parameters": dict}, ' @@ -88,26 +83,25 @@ def __init__(self): self.vpc_config = None self._original_deploy = None self.secret_key = None - self.engine = None self.hf_model_config = None self._default_tensor_parallel_degree = None self._default_data_type = None self._default_max_tokens = None - self._default_max_new_tokens = None self.pysdk_model = None - self.overwrite_props_from_file = None self.schema_builder = None self.env_vars = None self.nb_instance_type = None self.ram_usage_model_load = None + self.role_arn = None + self.name = None @abstractmethod def _prepare_for_mode(self): - """Placeholder docstring""" + """Abstract method""" @abstractmethod def _get_client_translators(self): - """Placeholder docstring""" + """Abstract method""" def _is_djl(self): """Placeholder 
docstring""" @@ -130,37 +124,16 @@ def _validate_djl_serving_sample_data(self): def _create_djl_model(self) -> Type[Model]: """Placeholder docstring""" - code_dir = str(Path(self.model_path).joinpath(_CODE_FOLDER)) - - kwargs = { - "model_id": self.model, - "role": self.serve_settings.role_arn, - "entry_point": _DJL_MODEL_BUILDER_ENTRY_POINT, - "dtype": self._default_data_type, - "sagemaker_session": self.sagemaker_session, - "source_dir": code_dir, - "env": self.env_vars, - "hf_hub_token": self.env_vars.get("HUGGING_FACE_HUB_TOKEN"), - "image_config": self.image_config, - "vpc_config": self.vpc_config, - } - - if self.engine == _DjlEngine.DEEPSPEED: - pysdk_model = DeepSpeedModel( - tensor_parallel_degree=self._default_tensor_parallel_degree, - max_tokens=self._default_max_tokens, - **kwargs, - ) - elif self.engine == _DjlEngine.FASTER_TRANSFORMER: - pysdk_model = FasterTransformerModel( - tensor_parallel_degree=self._default_tensor_parallel_degree, - **kwargs, - ) - else: - pysdk_model = HuggingFaceAccelerateModel( - number_of_partitions=self._default_tensor_parallel_degree, - **kwargs, - ) + pysdk_model = DJLModel( + model_id=self.model, + role=self.serve_settings.role_arn, + sagemaker_session=self.sagemaker_session, + env=self.env_vars, + huggingface_hub_token=self.env_vars.get("HF_TOKEN"), + image_config=self.image_config, + vpc_config=self.vpc_config, + name=self.name, + ) if not self.image_uri: self.image_uri = pysdk_model.serving_image_uri(self.sagemaker_session.boto_region_name) @@ -174,7 +147,7 @@ def _create_djl_model(self) -> Type[Model]: @_capture_telemetry("djl.deploy") def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: - """Placeholder docstring""" + """Returns predictor depending on local mode or endpoint mode""" timeout = kwargs.get("model_data_download_timeout") if timeout: self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(timeout)}) @@ -196,7 +169,6 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa else: raise ValueError("Mode %s is not supported!" 
% overwrite_mode) - manual_set_props = None if self.mode == Mode.SAGEMAKER_ENDPOINT: if self.nb_instance_type and "instance_type" not in kwargs: kwargs.update({"instance_type": self.nb_instance_type}) @@ -212,20 +184,24 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa default_tensor_parallel_degree = _get_default_tensor_parallel_degree( self.hf_model_config, tot_gpus ) - manual_set_props = { - "option.tensor_parallel_degree": str(default_tensor_parallel_degree) + "\n" - } - - prepare_for_djl_serving( - model_path=self.model_path, - model=self.pysdk_model, - dependencies=self.dependencies, - overwrite_props_from_file=self.overwrite_props_from_file, - manual_set_props=manual_set_props, - ) + self.pysdk_model.env.update( + {"TENSOR_PARALLEL_DEGREE": str(default_tensor_parallel_degree)} + ) serializer = self.schema_builder.input_serializer deserializer = self.schema_builder._output_deserializer + + if self.mode == Mode.IN_PROCESS: + + predictor = InProcessModePredictor( + self.modes[str(Mode.IN_PROCESS)], serializer, deserializer + ) + + self.modes[str(Mode.IN_PROCESS)].create_server( + predictor, + ) + return predictor + if self.mode == Mode.LOCAL_CONTAINER: timeout = kwargs.get("model_data_download_timeout") @@ -239,7 +215,7 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa timeout if timeout else 1800, self.secret_key, predictor, - self.env_vars, + self.pysdk_model.env, ) ram_usage_after = _get_ram_usage_mb() @@ -254,9 +230,10 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa del kwargs["role"] # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = self._prepare_for_mode() - self.env_vars.update(env_vars) - self.pysdk_model.env.update(self.env_vars) + if not _is_optimized(self.pysdk_model): + self.pysdk_model.model_data, env_vars = self._prepare_for_mode() + self.env_vars.update(env_vars) + self.pysdk_model.env.update(self.env_vars) # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: @@ -265,6 +242,7 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa # if has not been built for local container we must use cache # that hosting has write access to. 
self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp" + self.pysdk_model.env["HF_HOME"] = "/tmp" self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp" if "endpoint_logging" not in kwargs: @@ -280,28 +258,26 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa def _build_for_hf_djl(self): """Placeholder docstring""" - self.overwrite_props_from_file = True self.nb_instance_type = _get_nb_instance() _create_dir_structure(self.model_path) - self.engine, self.hf_model_config = _auto_detect_engine( - self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - ) - if not hasattr(self, "pysdk_model"): - ( - self._default_tensor_parallel_degree, - self._default_data_type, - _, - self._default_max_tokens, - self._default_max_new_tokens, - ) = _set_serve_properties(self.hf_model_config, self.schema_builder) + self.env_vars.update({"HF_MODEL_ID": self.model}) + + self.hf_model_config = _get_model_config_properties_from_hf( + self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_TOKEN") + ) + default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations( + self.model, self.hf_model_config, self.schema_builder + ) + self.env_vars.update(default_djl_configurations) self.schema_builder.sample_input["parameters"][ "max_new_tokens" - ] = self._default_max_new_tokens + ] = _default_max_new_tokens + self.pysdk_model = self._create_djl_model() - if self.mode == Mode.LOCAL_CONTAINER: + if self.mode in LOCAL_MODES: self._prepare_for_mode() return self.pysdk_model @@ -315,8 +291,6 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): ) return self.pysdk_model - self.overwrite_props_from_file = False - admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees( self.hf_model_config ) @@ -336,8 +310,9 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): "Trying tensor parallel degree: %s, dtype: %s...", tensor_parallel_degree, dtype ) - self._default_tensor_parallel_degree = tensor_parallel_degree - self._default_data_type = dtype + self.env_vars.update( + {"TENSOR_PARALLEL_DEGREE": str(tensor_parallel_degree), "OPTION_DTYPE": dtype} + ) self.pysdk_model = self._create_djl_model() try: @@ -352,15 +327,15 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): predictor, self.schema_builder.sample_input ) - serving_properties = self.pysdk_model.generate_serving_properties() + tested_env = self.pysdk_model.env.copy() logger.info( "Average latency: %s, throughput/s: %s for configuration: %s", avg_latency, throughput_per_second, - serving_properties, + tested_env, ) benchmark_results[avg_latency] = [ - serving_properties, + tested_env, p90, avg_tokens_per_second, throughput_per_second, @@ -448,6 +423,12 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): if best_tuned_combination: self._default_tensor_parallel_degree = best_tuned_combination[1] self._default_data_type = best_tuned_combination[2] + self.env_vars.update( + { + "TENSOR_PARALLEL_DEGREE": str(self._default_tensor_parallel_degree), + "OPTION_DTYPE": self._default_data_type, + } + ) self.pysdk_model = self._create_djl_model() _pretty_print_results(benchmark_results) @@ -455,7 +436,7 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): "Model Configuration: %s was most performant with avg latency: %s, " "p90 latency: %s, average tokens per second: %s, throughput/s: %s, " "standard deviation of request %s", - self.pysdk_model.generate_serving_properties(), + self.pysdk_model.env, best_tuned_combination[0], best_tuned_combination[3], 
best_tuned_combination[4], @@ -463,40 +444,32 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800): best_tuned_combination[6], ) else: - ( - self._default_tensor_parallel_degree, - self._default_data_type, - _, - self._default_max_tokens, - self._default_max_new_tokens, - ) = _set_serve_properties(self.hf_model_config, self.schema_builder) + default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations( + self.model, self.hf_model_config, self.schema_builder + ) + self.env_vars.update(default_djl_configurations) self.schema_builder.sample_input["parameters"][ "max_new_tokens" - ] = self._default_max_new_tokens + ] = _default_max_new_tokens self.pysdk_model = self._create_djl_model() logger.debug( "Failed to gather any tuning results. " "Please inspect the stack trace emitted from live logging for more details. " "Falling back to default serving.properties: %s", - self.pysdk_model.generate_serving_properties(), + self.pysdk_model.env, ) - prepare_for_djl_serving( - model_path=self.model_path, - model=self.pysdk_model, - dependencies=self.dependencies, - overwrite_props_from_file=self.overwrite_props_from_file, - ) - self.overwrite_props_from_file = True - return self.pysdk_model def _build_for_djl(self): """Placeholder docstring""" self._validate_djl_serving_sample_data() self.secret_key = None - self.pysdk_model = self._build_for_hf_djl() self.pysdk_model.tune = self._tune_for_hf_djl + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index e3368869fe..bf6fcaa376 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -14,15 +14,20 @@ from __future__ import absolute_import import copy +import re from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import Type +from typing import Type, Any, List, Dict, Optional, Tuple import logging +from botocore.exceptions import ClientError + +from sagemaker.enums import Tag from sagemaker.model import Model from sagemaker import model_uris from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees +from sagemaker.serve.model_server.multi_model_server.prepare import prepare_mms_js_resources from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.exceptions import ( @@ -32,9 +37,25 @@ LocalModelLoadException, SkipTuningComboException, ) +from sagemaker.serve.utils.optimize_utils import ( + _generate_model_source, + _update_environment_variables, + _extract_speculative_draft_model_provider, + _is_image_compatible_with_optimization_job, + _generate_channel_name, + _extract_optimization_config_and_env, + _is_optimized, + _custom_speculative_decoding, + SPECULATIVE_DRAFT_MODEL, + _is_inferentia_or_trainium, + _jumpstart_speculative_decoding, + _deployment_config_contains_draft_model, + _is_draft_model_jumpstart_provided, +) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, TgiLocalModePredictor, + TransformersLocalModePredictor, ) from sagemaker.serve.utils.local_hardware import ( _get_nb_instance, @@ -51,6 +72,7 @@ from 
sagemaker.serve.utils.types import ModelServer from sagemaker.base_predictor import PredictorBase from sagemaker.jumpstart.model import JumpStartModel +from sagemaker.utils import Tags _DJL_MODEL_BUILDER_ENTRY_POINT = "inference.py" _NO_JS_MODEL_EX = "HuggingFace JumpStart Model ID not detected. Building for HuggingFace Model ID." @@ -60,6 +82,7 @@ ModelServer.DJL_SERVING, ModelServer.TGI, } +_JS_MINIMUM_VERSION_IMAGE = "{}:0.31.0-lmi13.0.0-cu124" logger = logging.getLogger(__name__) @@ -90,13 +113,23 @@ def __init__(self): self.existing_properties = None self.prepared_for_tgi = None self.prepared_for_djl = None + self.prepared_for_mms = None self.schema_builder = None + self.instance_type = None self.nb_instance_type = None self.ram_usage_model_load = None - self.jumpstart = None + self.model_hub = None + self.model_metadata = None + self.role_arn = None + self.is_fine_tuned = None + self.is_compiled = False + self.is_quantized = False + self.speculative_decoding_draft_model_source = None + self.deployment_config_name = None + self.name = None @abstractmethod - def _prepare_for_mode(self): + def _prepare_for_mode(self, **kwargs): """Placeholder docstring""" @abstractmethod @@ -105,6 +138,9 @@ def _get_client_translators(self): def _is_jumpstart_model_id(self) -> bool: """Placeholder docstring""" + if self.model is None: + return False + try: model_uris.retrieve(model_id=self.model, model_version="*", model_scope=_JS_SCOPE) except KeyError: @@ -116,8 +152,13 @@ def _is_jumpstart_model_id(self) -> bool: def _create_pre_trained_js_model(self) -> Type[Model]: """Placeholder docstring""" - pysdk_model = JumpStartModel(self.model, vpc_config=self.vpc_config) - pysdk_model.sagemaker_session = self.sagemaker_session + pysdk_model = JumpStartModel( + self.model, + vpc_config=self.vpc_config, + sagemaker_session=self.sagemaker_session, + name=self.name, + instance_type=self.instance_type, + ) self._original_deploy = pysdk_model.deploy pysdk_model.deploy = self._js_builder_deploy_wrapper @@ -126,6 +167,7 @@ def _create_pre_trained_js_model(self) -> Type[Model]: @_capture_telemetry("jumpstart.deploy") def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: """Placeholder docstring""" + env = {} if "mode" in kwargs and kwargs.get("mode") != self.mode: overwrite_mode = kwargs.get("mode") # mode overwritten by customer during model.deploy() @@ -137,8 +179,13 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: if overwrite_mode == Mode.SAGEMAKER_ENDPOINT: self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT - if not hasattr(self, "prepared_for_djl") or not hasattr(self, "prepared_for_tgi"): - self.pysdk_model.model_data, env = self._prepare_for_mode() + if ( + not hasattr(self, "prepared_for_djl") + or not hasattr(self, "prepared_for_tgi") + or not hasattr(self, "prepared_for_mms") + ): + if not _is_optimized(self.pysdk_model): + self.pysdk_model.model_data, env = self._prepare_for_mode() elif overwrite_mode == Mode.LOCAL_CONTAINER: self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER @@ -160,9 +207,15 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: dependencies=self.dependencies, model_data=self.pysdk_model.model_data, ) + elif not hasattr(self, "prepared_for_mms"): + self.js_model_config, self.prepared_for_mms = prepare_mms_js_resources( + model_path=self.model_path, + js_id=self.model, + dependencies=self.dependencies, + model_data=self.pysdk_model.model_data, + ) self._prepare_for_mode() - env = {} else: 
raise ValueError("Mode %s is not supported!" % overwrite_mode) @@ -179,6 +232,10 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: predictor = TgiLocalModePredictor( self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer ) + elif self.model_server == ModelServer.MMS: + predictor = TransformersLocalModePredictor( + self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer + ) ram_usage_before = _get_ram_usage_mb() self.modes[str(Mode.LOCAL_CONTAINER)].create_server( @@ -231,7 +288,7 @@ def _build_for_djl_jumpstart(self): ) self._prepare_for_mode() elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_djl"): - self.nb_instance_type = _get_nb_instance() + self.nb_instance_type = self.instance_type or _get_nb_instance() self.pysdk_model.model_data, env = self._prepare_for_mode() self.pysdk_model.env.update(env) @@ -254,6 +311,24 @@ def _build_for_tgi_jumpstart(self): self.pysdk_model.env.update(env) + def _build_for_mms_jumpstart(self): + """Placeholder docstring""" + + env = {} + if self.mode == Mode.LOCAL_CONTAINER: + if not hasattr(self, "prepared_for_mms"): + self.js_model_config, self.prepared_for_mms = prepare_mms_js_resources( + model_path=self.model_path, + js_id=self.model, + dependencies=self.dependencies, + model_data=self.pysdk_model.model_data, + ) + self._prepare_for_mode() + elif self.mode == Mode.SAGEMAKER_ENDPOINT and hasattr(self, "prepared_for_mms"): + self.pysdk_model.model_data, env = self._prepare_for_mode() + + self.pysdk_model.env.update(env) + def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800): """Tune for Jumpstart Models in Local Mode. @@ -264,7 +339,7 @@ def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800) returns: Tuned Model. """ - if self.mode != Mode.LOCAL_CONTAINER: + if self.mode == Mode.SAGEMAKER_ENDPOINT: logger.warning( "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER ) @@ -431,18 +506,134 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration ) + def set_deployment_config(self, config_name: str, instance_type: str) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. + """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + raise Exception("Cannot set deployment config to an uninitialized model.") + + self.pysdk_model.set_deployment_config(config_name, instance_type) + self.deployment_config_name = config_name + + self.instance_type = instance_type + + # JS-benchmarked models only include SageMaker-provided SD models + if self.pysdk_model.additional_model_data_sources: + self.speculative_decoding_draft_model_source = "sagemaker" + self.pysdk_model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "sagemaker"}, + ) + self.pysdk_model.remove_tag_with_key(Tag.OPTIMIZATION_JOB_NAME) + self.pysdk_model.remove_tag_with_key(Tag.FINE_TUNING_MODEL_PATH) + self.pysdk_model.remove_tag_with_key(Tag.FINE_TUNING_JOB_NAME) + + def get_deployment_config(self) -> Optional[Dict[str, Any]]: + """Gets the deployment config to apply to the model. + + Returns: + Optional[Dict[str, Any]]: Deployment config to apply to this model. 
+ """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self._build_for_jumpstart() + + return self.pysdk_model.deployment_config + + def display_benchmark_metrics(self, **kwargs): + """Display Markdown Benchmark Metrics for deployment configs.""" + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self._build_for_jumpstart() + + self.pysdk_model.display_benchmark_metrics(**kwargs) + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for ``This`` model in the current region. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. + """ + if not hasattr(self, "pysdk_model") or self.pysdk_model is None: + self._build_for_jumpstart() + + return self.pysdk_model.list_deployment_configs() + + def _is_fine_tuned_model(self) -> bool: + """Checks whether a fine-tuned model exists.""" + return self.model_metadata and ( + self.model_metadata.get("FINE_TUNING_MODEL_PATH") + or self.model_metadata.get("FINE_TUNING_JOB_NAME") + ) + + def _update_model_data_for_fine_tuned_model(self, pysdk_model: Type[Model]) -> Type[Model]: + """Set the model path and data and add fine-tuning tags for the model.""" + # TODO: determine precedence of FINE_TUNING_MODEL_PATH and FINE_TUNING_JOB_NAME + if fine_tuning_model_path := self.model_metadata.get("FINE_TUNING_MODEL_PATH"): + if not re.match("^(https|s3)://([^/]+)/?(.*)$", fine_tuning_model_path): + raise ValueError( + f"Invalid path for FINE_TUNING_MODEL_PATH: {fine_tuning_model_path}." + ) + pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path + pysdk_model.add_tags( + {"Key": Tag.FINE_TUNING_MODEL_PATH, "Value": fine_tuning_model_path} + ) + logger.info( + "FINE_TUNING_MODEL_PATH detected. Using fine-tuned model found in %s.", + fine_tuning_model_path, + ) + return pysdk_model + + if fine_tuning_job_name := self.model_metadata.get("FINE_TUNING_JOB_NAME"): + try: + response = self.sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=fine_tuning_job_name + ) + fine_tuning_model_path = response["ModelArtifacts"]["S3ModelArtifacts"] + pysdk_model.model_data["S3DataSource"]["S3Uri"] = fine_tuning_model_path + pysdk_model.add_tags( + [ + {"key": Tag.FINE_TUNING_JOB_NAME, "value": fine_tuning_job_name}, + {"key": Tag.FINE_TUNING_MODEL_PATH, "value": fine_tuning_model_path}, + ] + ) + logger.info( + "FINE_TUNING_JOB_NAME detected. Using fine-tuned model found in %s.", + fine_tuning_model_path, + ) + return pysdk_model + except ClientError: + raise ValueError( + f"Invalid job name for FINE_TUNING_JOB_NAME: {fine_tuning_job_name}." + ) + + raise ValueError( + "Input model not found. Please provide either `model_path`, or " + "`FINE_TUNING_MODEL_PATH` or `FINE_TUNING_JOB_NAME` under `model_metadata`." + ) + def _build_for_jumpstart(self): """Placeholder docstring""" + if hasattr(self, "pysdk_model") and self.pysdk_model is not None: + return self.pysdk_model + # we do not pickle for jumpstart. set to none self.secret_key = None - self.jumpstart = True pysdk_model = self._create_pre_trained_js_model() - image_uri = pysdk_model.image_uri logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri) + if self._is_fine_tuned_model(): + self.is_fine_tuned = True + pysdk_model = self._update_model_data_for_fine_tuned_model(pysdk_model) + if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT: raise ValueError( "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode." 
@@ -451,7 +642,6 @@ def _build_for_jumpstart(self): if "djl-inference" in image_uri: logger.info("Building for DJL JumpStart Model ID...") self.model_server = ModelServer.DJL_SERVING - self.pysdk_model = pysdk_model self.image_uri = self.pysdk_model.image_uri @@ -461,21 +651,196 @@ def _build_for_jumpstart(self): elif "tgi-inference" in image_uri: logger.info("Building for TGI JumpStart Model ID...") self.model_server = ModelServer.TGI - self.pysdk_model = pysdk_model self.image_uri = self.pysdk_model.image_uri self._build_for_tgi_jumpstart() self.pysdk_model.tune = self.tune_for_tgi_jumpstart - else: + elif "huggingface-pytorch-inference:" in image_uri: + logger.info("Building for MMS JumpStart Model ID...") + self.model_server = ModelServer.MMS + self.pysdk_model = pysdk_model + self.image_uri = self.pysdk_model.image_uri + + self._build_for_mms_jumpstart() + elif self.mode != Mode.SAGEMAKER_ENDPOINT: raise ValueError( - "JumpStart Model ID was not packaged with djl-inference or tgi-inference container." + "JumpStart Model ID was not packaged " + "with djl-inference, tgi-inference, or mms-inference container." ) + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model - def _is_gated_model(self, model) -> bool: + def _optimize_for_jumpstart( + self, + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + accept_eula: Optional[bool] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = None, + ) -> Optional[Dict[str, Any]]: + """Runs a model optimization job. + + Args: + output_path (Optional[str]): Specifies where to store the compiled/quantized model. + instance_type (str): Target deployment instance type that the model is optimized for. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + ``None``. + + Returns: + Dict[str, Any]: Model optimization job input arguments. 
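+
+         Example of the returned job arguments (shape only; every value below is a
+         placeholder)::
+
+             {
+                 "OptimizationJobName": "<job-name>",
+                 "ModelSource": {"S3": {"S3Uri": "s3://<bucket>/<model-prefix>"}},
+                 "DeploymentInstanceType": "ml.g5.2xlarge",
+                 "OptimizationConfigs": [{"ModelQuantizationConfig": {...}}],
+                 "OutputConfig": {"S3OutputLocation": "s3://<bucket>/<output-prefix>"},
+                 "RoleArn": "<role-arn>",
+             }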
+ """ + if self._is_gated_model() and accept_eula is not True: + raise ValueError( + f"Model '{self.model}' requires accepting end-user license agreement (EULA)." + ) + + is_compilation = (compilation_config is not None) or _is_inferentia_or_trainium( + instance_type + ) + + pysdk_model_env_vars = dict() + if is_compilation: + pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type) + + # optimization_config can contain configs for both quantization and compilation + ( + optimization_config, + quantization_override_env, + compilation_override_env, + sharding_override_env, + ) = _extract_optimization_config_and_env( + quantization_config, compilation_config, sharding_config + ) + + if not optimization_config: + optimization_config = {} + + if not optimization_config.get("ModelCompilationConfig") and is_compilation: + # Fallback to default if override_env is None or empty + if not compilation_override_env: + compilation_override_env = pysdk_model_env_vars + + # Update optimization_config with ModelCompilationConfig + override_compilation_config = ( + {"OverrideEnvironment": compilation_override_env} + if compilation_override_env + else {} + ) + optimization_config["ModelCompilationConfig"] = override_compilation_config + + if speculative_decoding_config: + self._set_additional_model_source(speculative_decoding_config) + else: + deployment_config = self._find_compatible_deployment_config(None) + if deployment_config: + self.pysdk_model.set_deployment_config( + config_name=deployment_config.get("DeploymentConfigName"), + instance_type=deployment_config.get("InstanceType"), + ) + pysdk_model_env_vars = self.pysdk_model.env + + model_source = _generate_model_source(self.pysdk_model.model_data, accept_eula) + optimization_env_vars = _update_environment_variables(pysdk_model_env_vars, env_vars) + + output_config = {"S3OutputLocation": output_path} + if kms_key: + output_config["KmsKeyId"] = kms_key + + deployment_config_instance_type = ( + self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get("InstanceType") + if self.pysdk_model.deployment_config + else None + ) + self.instance_type = instance_type or deployment_config_instance_type or _get_nb_instance() + + create_optimization_job_args = { + "OptimizationJobName": job_name, + "ModelSource": model_source, + "DeploymentInstanceType": self.instance_type, + "OptimizationConfigs": [{k: v} for k, v in optimization_config.items()], + "OutputConfig": output_config, + "RoleArn": self.role_arn, + } + + if optimization_env_vars: + create_optimization_job_args["OptimizationEnvironment"] = optimization_env_vars + if max_runtime_in_sec: + create_optimization_job_args["StoppingCondition"] = { + "MaxRuntimeInSeconds": max_runtime_in_sec + } + if tags: + create_optimization_job_args["Tags"] = tags + if vpc_config: + create_optimization_job_args["VpcConfig"] = vpc_config + + if accept_eula: + self.pysdk_model.accept_eula = accept_eula + if isinstance(self.pysdk_model.model_data, dict): + self.pysdk_model.model_data["S3DataSource"]["ModelAccessConfig"] = { + "AcceptEula": True + } + + optimization_env_vars = _update_environment_variables( + optimization_env_vars, + { + **(quantization_override_env or {}), + **(compilation_override_env or {}), + **(sharding_override_env or {}), + }, + ) + if optimization_env_vars: + self.pysdk_model.env.update(optimization_env_vars) + + if sharding_config and self.pysdk_model._enable_network_isolation: + logger.warning( + "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " + "Loading 
of model requires network access. Setting it to False." + ) + self.pysdk_model._enable_network_isolation = False + + if quantization_config or sharding_config or is_compilation: + # only apply default image for vLLM usecases. + # vLLM does not support compilation for now so skip on compilation + return ( + create_optimization_job_args + if is_compilation + else self._set_optimization_image_default(create_optimization_job_args) + ) + return None + + def _is_gated_model(self, model=None) -> bool: """Determine if ``this`` Model is Gated Args: @@ -483,10 +848,251 @@ def _is_gated_model(self, model) -> bool: Returns: bool: ``True`` if ``this`` Model is Gated """ - s3_uri = model.model_data + s3_uri = model.model_data if model else self.pysdk_model.model_data if isinstance(s3_uri, dict): s3_uri = s3_uri.get("S3DataSource").get("S3Uri") if s3_uri is None: return False return "private" in s3_uri + + def _set_additional_model_source( + self, speculative_decoding_config: Optional[Dict[str, Any]] = None + ) -> None: + """Set Additional Model Source to ``this`` model. + + Args: + speculative_decoding_config (Optional[Dict[str, Any]]): Speculative decoding config. + accept_eula (Optional[bool]): For models that require a Model Access Config. + """ + if speculative_decoding_config: + model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) + + channel_name = _generate_channel_name(self.pysdk_model.additional_model_data_sources) + + if model_provider in ["sagemaker", "auto"]: + additional_model_data_sources = ( + self.pysdk_model.deployment_config.get("DeploymentArgs", {}).get( + "AdditionalDataSources" + ) + if self.pysdk_model.deployment_config + else None + ) + if additional_model_data_sources is None: + deployment_config = self._find_compatible_deployment_config( + speculative_decoding_config + ) + if deployment_config: + if model_provider == "sagemaker" and _is_draft_model_jumpstart_provided( + deployment_config + ): + raise ValueError( + "No `Sagemaker` provided draft model was found for " + f"{self.model}. Try setting `ModelProvider` " + "to `Auto` instead." + ) + + try: + self.pysdk_model.set_deployment_config( + config_name=deployment_config.get("DeploymentConfigName"), + instance_type=deployment_config.get("InstanceType"), + ) + except ValueError as e: + raise ValueError( + f"{e} If using speculative_decoding_config, " + "accept the EULA by setting `AcceptEula`=True." + ) + else: + raise ValueError( + "Cannot find deployment config compatible for optimization job." + ) + else: + if model_provider == "sagemaker" and _is_draft_model_jumpstart_provided( + self.pysdk_model.deployment_config + ): + raise ValueError( + "No `Sagemaker` provided draft model was found for " + f"{self.model}. Try setting `ModelProvider` " + "to `Auto` instead." 
+ ) + + self.pysdk_model.env.update( + {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"} + ) + self.pysdk_model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": model_provider}, + ) + elif model_provider == "jumpstart": + _jumpstart_speculative_decoding( + model=self.pysdk_model, + speculative_decoding_config=speculative_decoding_config, + sagemaker_session=self.sagemaker_session, + ) + else: + self.pysdk_model = _custom_speculative_decoding( + self.pysdk_model, + speculative_decoding_config, + speculative_decoding_config.get("AcceptEula", False), + ) + + def _find_compatible_deployment_config( + self, speculative_decoding_config: Optional[Dict] = None + ) -> Optional[Dict[str, Any]]: + """Finds compatible model deployment config for optimization job. + + Args: + speculative_decoding_config (Optional[Dict]): Speculative decoding config. + + Returns: + Optional[Dict[str, Any]]: A compatible model deployment config for optimization job. + """ + model_provider = _extract_speculative_draft_model_provider(speculative_decoding_config) + for deployment_config in self.pysdk_model.list_deployment_configs(): + image_uri = deployment_config.get("deployment_config", {}).get("ImageUri") + + if _is_image_compatible_with_optimization_job( + image_uri + ) and _deployment_config_contains_draft_model(deployment_config): + if ( + model_provider in ["sagemaker", "auto"] + and deployment_config.get("DeploymentArgs", {}).get("AdditionalDataSources") + ) or model_provider == "custom": + return deployment_config + + # There's no matching config from jumpstart to add sagemaker draft model location + if model_provider in ["sagemaker", "auto"]: + return None + + # fall back to the default jumpstart model deployment config for optimization job + return self.pysdk_model.deployment_config + + def _get_neuron_model_env_vars( + self, instance_type: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """Gets Neuron model env vars. + + Args: + instance_type (Optional[str]): Instance type. + + Returns: + Optional[Dict[str, Any]]: Neuron Model environment variables. 
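+
+         Note: returns ``None`` when the resolved metadata config already lists the
+         requested instance type as supported, or when no Neuron-specific hosting
+         model is configured; otherwise the env vars of the Neuron JumpStart model
+         are returned.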
+ """ + metadata_configs = self.pysdk_model._metadata_configs + if metadata_configs: + metadata_config = metadata_configs.get(self.pysdk_model.config_name) + resolve_config = metadata_config.resolved_config if metadata_config else None + if resolve_config and instance_type not in resolve_config.get( + "supported_inference_instance_types", [] + ): + neuro_model_id = resolve_config.get("hosting_neuron_model_id") + neuro_model_version = resolve_config.get("hosting_neuron_model_version", "*") + if neuro_model_id: + job_model = JumpStartModel( + neuro_model_id, + model_version=neuro_model_version, + vpc_config=self.vpc_config, + ) + return job_model.env + return None + + def _set_optimization_image_default( + self, create_optimization_job_args: Dict[str, Any] + ) -> Dict[str, Any]: + """Defaults the optimization image to the JumpStart deployment config default + + Args: + create_optimization_job_args (Dict[str, Any]): create optimization job request + + Returns: + Dict[str, Any]: create optimization job request with image uri default + """ + default_image = self._get_default_vllm_image(self.pysdk_model.init_kwargs["image_uri"]) + + # find the latest vLLM image version + for optimization_config in create_optimization_job_args.get("OptimizationConfigs"): + if optimization_config.get("ModelQuantizationConfig"): + model_quantization_config = optimization_config.get("ModelQuantizationConfig") + provided_image = model_quantization_config.get("Image") + if provided_image and self._get_latest_lmi_version_from_list( + default_image, provided_image + ): + default_image = provided_image + if optimization_config.get("ModelShardingConfig"): + model_sharding_config = optimization_config.get("ModelShardingConfig") + provided_image = model_sharding_config.get("Image") + if provided_image and self._get_latest_lmi_version_from_list( + default_image, provided_image + ): + default_image = provided_image + + # default to latest vLLM version + for optimization_config in create_optimization_job_args.get("OptimizationConfigs"): + if optimization_config.get("ModelQuantizationConfig") is not None: + optimization_config.get("ModelQuantizationConfig")["Image"] = default_image + if optimization_config.get("ModelShardingConfig") is not None: + optimization_config.get("ModelShardingConfig")["Image"] = default_image + + logger.info("Defaulting to %s image for optimization job", default_image) + + return create_optimization_job_args + + def _get_default_vllm_image(self, image: str) -> bool: + """Ensures the minimum working image version for vLLM enabled optimization techniques + + Args: + image (str): JumpStart provided default image + + Returns: + str: minimum working image version + """ + dlc_name, _ = image.split(":") + major_version_number, _, _ = self._parse_lmi_version(image) + + if major_version_number < self._parse_lmi_version(_JS_MINIMUM_VERSION_IMAGE)[0]: + minimum_version_default = _JS_MINIMUM_VERSION_IMAGE.format(dlc_name) + return minimum_version_default + return image + + def _get_latest_lmi_version_from_list(self, version: str, version_to_compare: str) -> bool: + """LMI version comparator + + Args: + version (str): current version + version_to_compare (str): version to compare to + + Returns: + bool: if version_to_compare larger or equal to version + """ + parse_lmi_version = self._parse_lmi_version(version) + parse_lmi_version_to_compare = self._parse_lmi_version(version_to_compare) + + # Check major version + if parse_lmi_version_to_compare[0] > parse_lmi_version[0]: + return True + # Check minor version + if 
parse_lmi_version_to_compare[0] == parse_lmi_version[0]: + if parse_lmi_version_to_compare[1] > parse_lmi_version[1]: + return True + if parse_lmi_version_to_compare[1] == parse_lmi_version[1]: + # Check patch version + if parse_lmi_version_to_compare[2] >= parse_lmi_version[2]: + return True + return False + return False + return False + + def _parse_lmi_version(self, image: str) -> Tuple[int, int, int]: + """Parse out LMI version + + Args: + image (str): image to parse version out of + + Returns: + Tuple[int, int, int]: LMI version split into major, minor, patch + """ + _, dlc_tag = image.split(":") + _, lmi_version, _ = dlc_tag.split("-") + major_version, minor_version, patch_version = lmi_version.split(".") + major_version_number = major_version[3:] + + return (int(major_version_number), int(minor_version), int(patch_version)) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 06b3d70aeb..3c19e4aa43 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -11,42 +11,62 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Holds the ModelBuilder class and the ModelServer enum.""" -from __future__ import absolute_import +from __future__ import absolute_import, annotations + +import importlib.util +import json import uuid from typing import Any, Type, List, Dict, Optional, Union from dataclasses import dataclass, field import logging import os +import re from pathlib import Path +from botocore.exceptions import ClientError +from sagemaker_core.main.resources import TrainingJob + +from sagemaker.transformer import Transformer +from sagemaker.async_inference import AsyncInferenceConfig +from sagemaker.batch_inference.batch_transform_inference_config import BatchTransformInferenceConfig +from sagemaker.compute_resource_requirements import ResourceRequirements +from sagemaker.enums import Tag, EndpointType +from sagemaker.estimator import Estimator +from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket from sagemaker.s3 import S3Downloader - from sagemaker import Session from sagemaker.model import Model +from sagemaker.jumpstart.model import JumpStartModel from sagemaker.base_predictor import PredictorBase -from sagemaker.djl_inference import defaults from sagemaker.serializers import NumpySerializer, TorchTensorSerializer from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer from sagemaker.serve.builder.schema_builder import SchemaBuilder +from sagemaker.serve.builder.tf_serving_builder import TensorflowServing from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode from sagemaker.serve.mode.local_container_mode import LocalContainerMode +from sagemaker.serve.mode.in_process_mode import InProcessMode from sagemaker.serve.detector.pickler import save_pkl, save_xgboost from sagemaker.serve.builder.serve_settings import _ServeSettings from sagemaker.serve.builder.djl_builder import DJL +from sagemaker.serve.builder.tei_builder import TEI from sagemaker.serve.builder.tgi_builder import TGI from sagemaker.serve.builder.jumpstart_builder import JumpStart from sagemaker.serve.builder.transformers_builder import Transformers from sagemaker.predictor import Predictor from 
sagemaker.serve.model_format.mlflow.constants import ( MLFLOW_MODEL_PATH, + MLFLOW_TRACKING_ARN, + MLFLOW_RUN_ID_REGEX, + MLFLOW_REGISTRY_PATH_REGEX, + MODEL_PACKAGE_ARN_REGEX, MLFLOW_METADATA_FILE, MLFLOW_PIP_DEPENDENCY_FILE, ) from sagemaker.serve.model_format.mlflow.utils import ( _get_default_model_server_for_mlflow, - _mlflow_input_is_local_path, _download_s3_artifacts, _select_container_for_mlflow_model, _generate_mlflow_artifact_path, @@ -57,9 +77,23 @@ ) from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.spec.inference_base import CustomOrchestrator, AsyncCustomOrchestrator from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException -from sagemaker.serve.utils.predictors import _get_local_mode_predictor +from sagemaker.serve.utils.lineage_utils import _maintain_lineage_tracking_for_mlflow_model +from sagemaker.serve.utils.optimize_utils import ( + _generate_optimized_model, + _generate_model_source, + _extract_optimization_config_and_env, + _is_s3_uri, + _custom_speculative_decoding, + _extract_speculative_draft_model_provider, + _jumpstart_speculative_decoding, +) +from sagemaker.serve.utils.predictors import ( + _get_local_mode_predictor, + _get_in_process_mode_predictor, +) from sagemaker.serve.utils.hardware_detector import ( _get_gpu_info, _get_gpu_info_fallback, @@ -71,30 +105,43 @@ _get_model_base, ) from sagemaker.serve.model_server.torchserve.prepare import prepare_for_torchserve +from sagemaker.serve.model_server.smd.prepare import prepare_for_smd from sagemaker.serve.model_server.triton.triton_builder import Triton from sagemaker.serve.utils.telemetry_logger import _capture_telemetry -from sagemaker.serve.utils.types import ModelServer +from sagemaker.serve.utils.types import ModelServer, ModelHub from sagemaker.serve.validations.check_image_uri import is_1p_image_uri from sagemaker.serve.save_retrive.version_1_0_0.save.save_handler import SaveHandler from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import get_metadata from sagemaker.serve.validations.check_image_and_hardware_type import ( validate_image_uri_and_hardware, ) +from sagemaker.serverless import ServerlessInferenceConfig +from sagemaker.utils import Tags from sagemaker.workflow.entities import PipelineVariable -from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata - -logger = logging.getLogger(__name__) +from sagemaker.huggingface.llm_utils import ( + get_huggingface_model_metadata, + download_huggingface_model_metadata, +) +from sagemaker.serve.validations.optimization import _validate_optimization_configuration +from sagemaker.modules.train import ModelTrainer +from sagemaker.modules import logger -supported_model_server = { +# Any new server type should be added here +supported_model_servers = { ModelServer.TORCHSERVE, ModelServer.TRITON, ModelServer.DJL_SERVING, + ModelServer.TENSORFLOW_SERVING, + ModelServer.MMS, + ModelServer.TGI, + ModelServer.TEI, + ModelServer.SMD, } -# pylint: disable=attribute-defined-outside-init, disable=E1101 +# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901, disable=R1705 @dataclass -class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): +class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, TEI): """Class that builds a deployable model. 
Args: @@ -108,6 +155,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): * ``Mode.SAGEMAKER_ENDPOINT``: Launch on a SageMaker endpoint * ``Mode.LOCAL_CONTAINER``: Launch locally with a container + * ``Mode.IN_PROCESS``: Launch locally to a FastAPI server instead of using a container. shared_libs (List[str]): Any shared libraries you want to bring into the model packaging. dependencies (Optional[Dict[str, Any]): The dependencies of the model @@ -144,8 +192,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): The schema builder can be omitted for HuggingFace models with task types TextGeneration, TextClassification, and QuestionAnswering. Omitting SchemaBuilder is in beta for FillMask, and AutomaticSpeechRecognition use-cases. - model (Optional[Union[object, str]): Model object (with ``predict`` method to perform - inference) or a HuggingFace/JumpStart Model ID. Either ``model`` or ``inference_spec`` + model (Optional[Union[object, str, ModelTrainer, TrainingJob, Estimator]]): + Define object from which training artifacts can be extracted. + Either ``model`` or ``inference_spec`` is required for the model builder to build the artifact. inference_spec (InferenceSpec): The inference spec file with your customized ``invoke`` and ``load`` functions. @@ -165,13 +214,29 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): in order for model builder to build the artifacts correctly (according to the model server). Possible values for this argument are ``TORCHSERVE``, ``MMS``, ``TENSORFLOW_SERVING``, ``DJL_SERVING``, - ``TRITON``, and``TGI``. + ``TRITON``, ``TGI``, and ``TEI``. model_metadata (Optional[Dict[str, Any]): Dictionary used to override model metadata. Currently, ``HF_TASK`` is overridable for HuggingFace model. HF_TASK should be set for new models without task metadata in the Hub, adding unsupported task types will throw an exception. ``MLFLOW_MODEL_PATH`` is available for providing local path or s3 path to MLflow artifacts. However, ``MLFLOW_MODEL_PATH`` is experimental and is not - intended for production use at this moment. + intended for production use at this moment. ``CUSTOM_MODEL_PATH`` is available for + providing local path or s3 path to model artifacts. ``FINE_TUNING_MODEL_PATH`` is + available for providing s3 path to fine-tuned model artifacts. ``FINE_TUNING_JOB_NAME`` + is available for providing fine-tuned job name. Both ``FINE_TUNING_MODEL_PATH`` and + ``FINE_TUNING_JOB_NAME`` are mutually exclusive. + inference_component_name (Optional[str]): The name for an inference component + created from this ModelBuilder instance. This or ``resource_requirements`` must be set + to denote that this instance refers to an inference component. + modelbuilder_list: Optional[List[ModelBuilder]] = List of ModelBuilder objects which + can be built in bulk and subsequently deployed in bulk. Currently only supports + deployments for inference components. + resource_requirements: Optional[ResourceRequirements] = Defines the compute resources + allocated to run the model assigned to the inference component. This or + ``inference_component_name`` must be set to denote that this instance refers + to an inference component. If ``inference_component_name`` is set but this is not and a + JumpStart model ID is specified, pre-benchmarked deployment configs will attempt to be + retrieved for the model. 
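+
+     Example (minimal sketch; the model ID, role ARN, and instance type are
+     placeholders, and ``sample_input``/``sample_output`` must be supplied by the
+     caller)::
+
+         model_builder = ModelBuilder(
+             model="<huggingface-or-jumpstart-model-id>",
+             schema_builder=SchemaBuilder(sample_input, sample_output),
+             role_arn="<role-arn>",
+         )
+         model = model_builder.build()
+         predictor = model.deploy(
+             instance_type="ml.g5.2xlarge", initial_instance_count=1
+         )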
""" model_path: Optional[str] = field( @@ -185,7 +250,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): default=None, metadata={"help": "Define sagemaker session for execution"} ) name: Optional[str] = field( - default="model-name-" + uuid.uuid1().hex, + default_factory=lambda: "model-name-" + uuid.uuid1().hex, metadata={"help": "Define the model name"}, ) mode: Optional[Mode] = field( @@ -232,14 +297,9 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): schema_builder: Optional[SchemaBuilder] = field( default=None, metadata={"help": "Defines the i/o schema of the model"} ) - model: Optional[Union[object, str]] = field( + model: Optional[Union[object, str, ModelTrainer, TrainingJob, Estimator]] = field( default=None, - metadata={ - "help": ( - 'Model object with "predict" method to perform inference ' - "or HuggingFace/JumpStart Model ID" - ) - }, + metadata={"help": "Define object from which training artifacts can be extracted"}, ) inference_spec: InferenceSpec = field( default=None, @@ -272,35 +332,28 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): default=None, metadata={ "help": "Define the model metadata to override, currently supports `HF_TASK`, " - "`MLFLOW_MODEL_PATH`. HF_TASK should be set for new models without task metadata in " - "the Hub, Adding unsupported task types will throw an exception" + "`MLFLOW_MODEL_PATH`, `FINE_TUNING_MODEL_PATH`, `FINE_TUNING_JOB_NAME`, and " + "`CUSTOM_MODEL_PATH`. HF_TASK should be set for new models without task metadata " + "in the Hub, Adding unsupported task types will throw an exception." + }, + ) + inference_component_name: Optional[str] = field( + default=None, + metadata={ + "help": "Defines the name for an Inference Component created from this ModelBuilder." + }, + ) + modelbuilder_list: Optional[List[ModelBuilder]] = field( + default=None, + metadata={"help": "Defines a list of ModelBuilder objects."}, + ) + resource_requirements: Optional[ResourceRequirements] = field( + default=None, + metadata={ + "help": "Defines the compute resources allocated to run the model assigned" + " to the inference component." }, ) - - def _build_validations(self): - """Placeholder docstring""" - # TODO: Beta validations - remove after the launch - if self.mode == Mode.IN_PROCESS: - raise ValueError("IN_PROCESS mode is not supported yet!") - - if self.inference_spec and self.model: - raise ValueError("Cannot have both the Model and Inference spec in the builder") - - if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None: - raise ValueError( - "Model_server must be set when non-first-party image_uri is set. " - + "Supported model servers: %s" % supported_model_server - ) - - # Set TorchServe as default model server - if not self.model_server: - self.model_server = ModelServer.TORCHSERVE - - if self.model_server not in supported_model_server: - raise ValueError( - "%s is not supported yet! Supported model servers: %s" - % (self.model_server, supported_model_server) - ) def _save_model_inference_spec(self): """Placeholder docstring""" @@ -375,8 +428,15 @@ def _get_serve_setting(self): sagemaker_session=self.sagemaker_session, ) - def _prepare_for_mode(self): - """Placeholder docstring""" + def _prepare_for_mode( + self, model_path: Optional[str] = None, should_upload_artifacts: Optional[bool] = False + ): + """Prepare this `Model` for serving. + + Args: + model_path (Optional[str]): Model path + should_upload_artifacts (Optional[bool]): Whether to upload artifacts to S3. 
+ """ # TODO: move mode specific prepare steps under _model_builder_deploy_wrapper self.s3_upload_path = None if self.mode == Mode.SAGEMAKER_ENDPOINT: @@ -387,16 +447,17 @@ def _prepare_for_mode(self): self.s3_upload_path, env_vars_sagemaker = self.modes[ str(Mode.SAGEMAKER_ENDPOINT) ].prepare( - self.model_path, + (model_path or self.model_path), self.secret_key, self.serve_settings.s3_model_data_url, self.sagemaker_session, self.image_uri, - self.jumpstart if hasattr(self, "jumpstart") else False, + getattr(self, "model_hub", None) == ModelHub.JUMPSTART, + should_upload_artifacts=should_upload_artifacts, ) self.env_vars.update(env_vars_sagemaker) return self.s3_upload_path, env_vars_sagemaker - if self.mode == Mode.LOCAL_CONTAINER: + elif self.mode == Mode.LOCAL_CONTAINER: # init the LocalContainerMode object self.modes[str(Mode.LOCAL_CONTAINER)] = LocalContainerMode( inference_spec=self.inference_spec, @@ -408,9 +469,22 @@ def _prepare_for_mode(self): ) self.modes[str(Mode.LOCAL_CONTAINER)].prepare() return None + elif self.mode == Mode.IN_PROCESS: + # init the InProcessMode object + self.modes[str(Mode.IN_PROCESS)] = InProcessMode( + inference_spec=self.inference_spec, + model=self.model, + schema_builder=self.schema_builder, + session=self.sagemaker_session, + model_path=self.model_path, + env_vars=self.env_vars, + ) + self.modes[str(Mode.IN_PROCESS)].prepare() + return None raise ValueError( - "Please specify mode in: %s, %s" % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT) + "Please specify mode in: %s, %s, %s" + % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS) ) def _get_client_translators(self): @@ -425,7 +499,7 @@ def _get_client_translators(self): elif self.schema_builder: serializer = self.schema_builder.input_serializer else: - raise Exception("Cannot serialize") + raise Exception("Cannot serialize. Try providing a SchemaBuilder if not present.") deserializer = None if self.accept_type == "application/json": @@ -437,11 +511,13 @@ def _get_client_translators(self): elif self.schema_builder: deserializer = self.schema_builder.output_deserializer else: - raise Exception("Cannot deserialize") + raise Exception("Cannot deserialize. 
Try providing a SchemaBuilder if not present.") return serializer, deserializer - def _get_predictor(self, endpoint_name: str, sagemaker_session: Session) -> Predictor: + def _get_predictor( + self, endpoint_name: str, sagemaker_session: Session, component_name: Optional[str] = None + ) -> Predictor: """Placeholder docstring""" serializer, deserializer = self._get_client_translators() @@ -450,6 +526,7 @@ def _get_predictor(self, endpoint_name: str, sagemaker_session: Session) -> Pred sagemaker_session=sagemaker_session, serializer=serializer, deserializer=deserializer, + component_name=component_name, ) def _create_model(self): @@ -464,6 +541,7 @@ def _create_model(self): env=self.env_vars, sagemaker_session=self.sagemaker_session, predictor_cls=self._get_predictor, + name=self.name, ) # store the modes in the model so that we may @@ -471,6 +549,10 @@ def _create_model(self): self.pysdk_model.mode = self.mode self.pysdk_model.modes = self.modes self.pysdk_model.serve_settings = self.serve_settings + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session # dynamically generate a method to direct model.deploy() logic based on mode # unique method to models created via ModelBuilder() @@ -493,6 +575,13 @@ def _model_builder_register_wrapper(self, *args, **kwargs): self.pysdk_model.model_package_arn = new_model_package.model_package_arn new_model_package.deploy = self._model_builder_deploy_model_package_wrapper self.model_package = new_model_package + if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT: + _maintain_lineage_tracking_for_mlflow_model( + mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH], + s3_upload_path=self.s3_upload_path, + sagemaker_session=self.sagemaker_session, + tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN), + ) return new_model_package def _model_builder_deploy_model_package_wrapper(self, *args, **kwargs): @@ -507,6 +596,83 @@ def _model_builder_deploy_model_package_wrapper(self, *args, **kwargs): self.pysdk_model.model_package_arn = None return predictor + def _deploy_for_ic( + self, + *args, + ic_data: Dict[str, Any], + container_timeout_in_seconds: int = 300, + model_data_download_timeout: int = 3600, + instance_type: Optional[str] = None, + initial_instance_count: Optional[int] = None, + endpoint_name: Optional[str] = None, + **kwargs, + ) -> Predictor: + """Creates an Inference Component from a ModelBuilder.""" + ic_name = ic_data.get("Name", None) + model = ic_data.get("Model", None) + resource_requirements = ic_data.get("ResourceRequirements", {}) + + # Ensure resource requirements are set for non-JumpStart models + if not resource_requirements: + raise ValueError( + f"Cannot create/update inference component {ic_name} without resource requirements." 
+ ) + + # Check if the Inference Component exists + if ic_name and self._does_ic_exist(ic_name=ic_name): + logger.info("Updating Inference Component %s as it already exists.", ic_name) + + # Create spec for updating the IC + startup_parameters = {} + if model_data_download_timeout is not None: + startup_parameters["ModelDataDownloadTimeoutInSeconds"] = ( + model_data_download_timeout + ) + if container_timeout_in_seconds is not None: + startup_parameters["ContainerStartupHealthCheckTimeoutInSeconds"] = ( + container_timeout_in_seconds + ) + compute_rr = resource_requirements.get_compute_resource_requirements() + inference_component_spec = { + "ModelName": self.name, + "StartupParameters": startup_parameters, + "ComputeResourceRequirements": compute_rr, + } + runtime_config = {"CopyCount": resource_requirements.copy_count} + response = self.sagemaker_session.update_inference_component( + inference_component_name=ic_name, + specification=inference_component_spec, + runtime_config=runtime_config, + ) + return Predictor(endpoint_name=response.get("EndpointName"), component_name=ic_name) + else: + kwargs.update( + { + "resources": resource_requirements, + "endpoint_type": EndpointType.INFERENCE_COMPONENT_BASED, + "inference_component_name": ic_name, + "endpoint_logging": False, + } + ) + return model.deploy( + *args, + container_startup_health_check_timeout=container_timeout_in_seconds, + initial_instance_count=initial_instance_count, + instance_type=instance_type, + mode=Mode.SAGEMAKER_ENDPOINT, + endpoint_name=endpoint_name, + **kwargs, + ) + + def _does_ic_exist(self, ic_name: str) -> bool: + """Returns true if an Inference Component exists with the given name.""" + try: + self.sagemaker_session.describe_inference_component(inference_component_name=ic_name) + return True + except ClientError as e: + msg = e.response["Error"]["Message"] + return "Could not find inference component" not in msg + @_capture_telemetry("torchserve.deploy") def _model_builder_deploy_wrapper( self, @@ -521,6 +687,18 @@ def _model_builder_deploy_wrapper( if mode and mode != self.mode: self._overwrite_mode_in_deploy(overwrite_mode=mode) + if self.mode == Mode.IN_PROCESS: + serializer, deserializer = self._get_client_translators() + + predictor = _get_in_process_mode_predictor( + self.modes[str(Mode.IN_PROCESS)], serializer, deserializer + ) + + self.modes[str(Mode.IN_PROCESS)].create_server( + predictor, + ) + return predictor + if self.mode == Mode.LOCAL_CONTAINER: serializer, deserializer = self._get_client_translators() predictor = _get_local_mode_predictor( @@ -534,14 +712,11 @@ def _model_builder_deploy_wrapper( self.image_uri, container_timeout_in_second, self.secret_key, predictor ) return predictor + if self.mode == Mode.SAGEMAKER_ENDPOINT: # Validate parameters - if not instance_type: - raise ValueError("Missing required parameter `instance_type`") - - if not initial_instance_count: - raise ValueError("Missing required parameter `initial_instance_count`") - + # Instance type and instance count parameter validation is done based on deployment type + # and will be done inside Model.deploy() if is_1p_image_uri(image_uri=self.image_uri): validate_image_uri_and_hardware( image_uri=self.image_uri, @@ -551,12 +726,29 @@ def _model_builder_deploy_wrapper( if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - return self._original_deploy( + + if "inference_component_name" not in kwargs and self.inference_component_name: + kwargs["inference_component_name"] = self.inference_component_name + + if 
"resources" not in kwargs and self.resource_requirements: + kwargs["resources"] = self.resource_requirements + + kwargs.pop("mode", None) + self.pysdk_model.role = kwargs.pop("role", self.pysdk_model.role) + predictor = self._original_deploy( *args, instance_type=instance_type, initial_instance_count=initial_instance_count, **kwargs, ) + if getattr(self, "_is_mlflow_model", False) and self.mode == Mode.SAGEMAKER_ENDPOINT: + _maintain_lineage_tracking_for_mlflow_model( + mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH], + s3_upload_path=self.s3_upload_path, + sagemaker_session=self.sagemaker_session, + tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN), + ) + return predictor def _overwrite_mode_in_deploy(self, overwrite_mode: str): """Mode overwritten by customer during model.deploy()""" @@ -570,10 +762,12 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str): s3_upload_path, env_vars_sagemaker = self._prepare_for_mode() self.pysdk_model.model_data = s3_upload_path self.pysdk_model.env.update(env_vars_sagemaker) - elif overwrite_mode == Mode.LOCAL_CONTAINER: self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER self._prepare_for_mode() + elif overwrite_mode == Mode.IN_PROCESS: + self.mode = self.pysdk_model.mode = Mode.IN_PROCESS + self._prepare_for_mode() else: raise ValueError("Mode %s is not supported!" % overwrite_mode) @@ -581,20 +775,39 @@ def _build_for_torchserve(self) -> Type[Model]: """Build the model for torchserve""" self._save_model_inference_spec() - self._auto_detect_container() + if self.mode != Mode.IN_PROCESS: + self._auto_detect_container() - self.secret_key = prepare_for_torchserve( - model_path=self.model_path, - shared_libs=self.shared_libs, - dependencies=self.dependencies, - session=self.sagemaker_session, - image_uri=self.image_uri, - inference_spec=self.inference_spec, - ) + self.secret_key = prepare_for_torchserve( + model_path=self.model_path, + shared_libs=self.shared_libs, + dependencies=self.dependencies, + session=self.sagemaker_session, + image_uri=self.image_uri, + inference_spec=self.inference_spec, + ) self._prepare_for_mode() + self.model = self._create_model() + return self.model + + def _build_for_smd(self) -> Type[Model]: + """Build the model for SageMaker Distribution""" + self._save_model_inference_spec() - return self._create_model() + if self.mode != Mode.IN_PROCESS: + self._auto_detect_container() + + self.secret_key = prepare_for_smd( + model_path=self.model_path, + shared_libs=self.shared_libs, + dependencies=self.dependencies, + inference_spec=self.inference_spec, + ) + + self._prepare_for_mode() + self.model = self._create_model() + return self.model def _user_agent_decorator(self, func): """Placeholder docstring""" @@ -608,11 +821,25 @@ def wrapper(*args, **kwargs): return wrapper - def _check_if_input_is_mlflow_model(self) -> bool: - """Checks whether an MLmodel file exists in the given directory. 
+ def _handle_mlflow_input(self): + """Check whether an MLflow model is present and handle accordingly""" + self._is_mlflow_model = self._has_mlflow_arguments() + if not self._is_mlflow_model: + return + + mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH) + artifact_path = self._get_artifact_path(mlflow_model_path) + if not self._mlflow_metadata_exists(artifact_path): + return + + self._initialize_for_mlflow(artifact_path) + _validate_input_for_mlflow(self.model_server, self.env_vars.get("MLFLOW_MODEL_FLAVOR")) + + def _has_mlflow_arguments(self) -> bool: + """Check whether MLflow model arguments are present Returns: - bool: True if the MLmodel file exists, False otherwise. + bool: True if MLflow arguments are present, False otherwise. """ if self.inference_spec or self.model: logger.info( @@ -627,8 +854,8 @@ def _check_if_input_is_mlflow_model(self) -> bool: ) return False - path = self.model_metadata.get(MLFLOW_MODEL_PATH) - if not path: + mlflow_model_path = self.model_metadata.get(MLFLOW_MODEL_PATH) + if not mlflow_model_path: logger.info( "%s is not provided in ModelMetadata. ModelBuilder is not handling MLflow model " "input", @@ -636,7 +863,73 @@ def _check_if_input_is_mlflow_model(self) -> bool: ) return False - # Check for S3 path + return True + + def _get_artifact_path(self, mlflow_model_path: str) -> str: + """Retrieves the model artifact location given the Mlflow model input. + + Args: + mlflow_model_path (str): The MLflow model path input. + + Returns: + str: The path to the model artifact. + """ + if (is_run_id_type := re.match(MLFLOW_RUN_ID_REGEX, mlflow_model_path)) or re.match( + MLFLOW_REGISTRY_PATH_REGEX, mlflow_model_path + ): + mlflow_tracking_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN) + if not mlflow_tracking_arn: + raise ValueError( + "%s is not provided in ModelMetadata or through set_tracking_arn " + "but MLflow model path was provided." % MLFLOW_TRACKING_ARN, + ) + + if not importlib.util.find_spec("sagemaker_mlflow"): + raise ImportError( + "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed" + ) + + import mlflow + + mlflow.set_tracking_uri(mlflow_tracking_arn) + if is_run_id_type: + _, run_id, model_path = mlflow_model_path.split("/", 2) + artifact_uri = mlflow.get_run(run_id).info.artifact_uri + if not artifact_uri.endswith("/"): + artifact_uri += "/" + return artifact_uri + model_path + + mlflow_client = mlflow.MlflowClient() + if not mlflow_model_path.endswith("/"): + mlflow_model_path += "/" + + if "@" in mlflow_model_path: + _, model_name_and_alias, artifact_uri = mlflow_model_path.split("/", 2) + model_name, model_alias = model_name_and_alias.split("@") + model_metadata = mlflow_client.get_model_version_by_alias(model_name, model_alias) + else: + _, model_name, model_version, artifact_uri = mlflow_model_path.split("/", 3) + model_metadata = mlflow_client.get_model_version(model_name, model_version) + + source = model_metadata.source + if not source.endswith("/"): + source += "/" + return source + artifact_uri + + if re.match(MODEL_PACKAGE_ARN_REGEX, mlflow_model_path): + model_package = self.sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=mlflow_model_path + ) + return model_package["SourceUri"] + + return mlflow_model_path + + def _mlflow_metadata_exists(self, path: str) -> bool: + """Checks whether an MLmodel file exists in the given directory. + + Returns: + bool: True if the MLmodel file exists, False otherwise. 
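+
+         The given path may be either a local directory or an ``s3://`` prefix; both
+         are checked for the MLmodel metadata file.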
+ """ if path.startswith("s3://"): s3_downloader = S3Downloader() if not path.endswith("/"): @@ -648,14 +941,18 @@ def _check_if_input_is_mlflow_model(self) -> bool: file_path = os.path.join(path, MLFLOW_METADATA_FILE) return os.path.isfile(file_path) - def _initialize_for_mlflow(self) -> None: - """Initialize mlflow model artifacts, image uri and model server.""" - mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH) - if not _mlflow_input_is_local_path(mlflow_path): - # TODO: extend to package arn, run id and etc. - _download_s3_artifacts(mlflow_path, self.model_path, self.sagemaker_session) + def _initialize_for_mlflow(self, artifact_path: str) -> None: + """Initialize mlflow model artifacts, image uri and model server. + + Args: + artifact_path (str): The path to the artifact store. + """ + if artifact_path.startswith("s3://"): + _download_s3_artifacts(artifact_path, self.model_path, self.sagemaker_session) + elif os.path.exists(artifact_path): + _copy_directory_contents(artifact_path, self.model_path) else: - _copy_directory_contents(mlflow_path, self.model_path) + raise ValueError("Invalid path: %s" % artifact_path) mlflow_model_metadata_path = _generate_mlflow_artifact_path( self.model_path, MLFLOW_METADATA_FILE ) @@ -678,11 +975,240 @@ def _initialize_for_mlflow(self) -> None: self.env_vars.update({"MLFLOW_MODEL_FLAVOR": f"{deployment_flavor}"}) self.dependencies.update({"requirements": mlflow_model_dependency_path}) + @_capture_telemetry("ModelBuilder.build_training_job") + def _collect_training_job_model_telemetry(self): + """Dummy method to collect telemetry for training job handshake""" + return + + @_capture_telemetry("ModelBuilder.build_model_trainer") + def _collect_model_trainer_model_telemetry(self): + """Dummy method to collect telemetry for model trainer handshake""" + return + + @_capture_telemetry("ModelBuilder.build_estimator") + def _collect_estimator_model_telemetry(self): + """Dummy method to collect telemetry for estimator handshake""" + return + + def build( + self, + mode: Type[Mode] = None, + role_arn: str = None, + sagemaker_session: Optional[Session] = None, + ) -> Union[ModelBuilder, Type[Model]]: + """Creates deployable ``Model`` instances with all provided ``ModelBuilder`` objects. + + Args: + mode (Type[Mode], optional): The mode. Defaults to ``None``. + role_arn (str, optional): The IAM role arn. Defaults to ``None``. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Union[ModelBuilder, Type[Model]]: A deployable ``ModelBuilder`` object if multiple + ``ModelBuilders`` were built, or a deployable ``Model`` object. 
+ """ + if role_arn: + self.role_arn = role_arn + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() + + deployables = {} + + if not self.modelbuilder_list and not isinstance( + self.inference_spec, (CustomOrchestrator, AsyncCustomOrchestrator) + ): + self.serve_settings = self._get_serve_setting() + return self._build_single_modelbuilder( + mode=mode, + role_arn=self.role_arn, + sagemaker_session=sagemaker_session, + ) + + # Multi-ModelBuilder case: deploy + built_ic_models = [] + if self.modelbuilder_list: + logger.info("Detected ModelBuilders in modelbuilder_list.") + for mb in self.modelbuilder_list: + if mb.mode == Mode.IN_PROCESS or mb.mode == Mode.LOCAL_CONTAINER: + raise ValueError( + "Bulk ModelBuilder building is only supported for SageMaker Endpoint Mode." + ) + + if (not mb.resource_requirements and not mb.inference_component_name) and ( + not mb.inference_spec + or not isinstance( + mb.inference_spec, (CustomOrchestrator, AsyncCustomOrchestrator) + ) + ): + raise ValueError( + "Bulk ModelBuilder building is only supported for Inference Components " + + "and custom orchestrators." + ) + + for mb in self.modelbuilder_list: + # Custom orchestrator definition found in inference_spec + mb.serve_settings = mb._get_serve_setting() + # Build for Inference Component + logger.info("Building ModelBuilder %s.", mb.name) + # Get JS deployment configs if ResourceRequirements not set + + mb = mb._get_ic_resource_requirements(mb=mb) + + built_model = mb._build_single_modelbuilder( + role_arn=self.role_arn, sagemaker_session=self.sagemaker_session + ) + built_ic_models.append( + { + "Name": mb.inference_component_name, + "ResourceRequirements": mb.resource_requirements, + "Model": built_model, + } + ) + logger.info( + "=====================Build for %s complete.===================", + mb.model, + ) + deployables["InferenceComponents"] = built_ic_models + + if isinstance(self.inference_spec, (CustomOrchestrator, AsyncCustomOrchestrator)): + logger.info("Building custom orchestrator.") + if self.mode == Mode.IN_PROCESS or self.mode == Mode.LOCAL_CONTAINER: + raise ValueError( + "Custom orchestrator deployment is only supported for" + "SageMaker Endpoint Mode." + ) + self.serve_settings = self._get_serve_setting() + cpu_or_gpu_instance = self._get_processing_unit() + self.image_uri = self._get_smd_image_uri(processing_unit=cpu_or_gpu_instance) + self.model_server = ModelServer.SMD + built_orchestrator = self._build_single_modelbuilder( + mode=Mode.SAGEMAKER_ENDPOINT, + role_arn=role_arn, + sagemaker_session=sagemaker_session, + ) + if not self.resource_requirements: + logger.info( + "Custom orchestrator resource_requirements not found. " + "Building as a SageMaker Endpoint instead of Inference Component." + ) + deployables["CustomOrchestrator"] = { + "Mode": "Endpoint", + "Model": built_orchestrator, + } + else: + # Network isolation of ICs on an endpoint must be consistent + if built_ic_models: + if ( + self.dependencies["auto"] + or "requirements" in self.dependencies + or "custom" in self.dependencies + ): + logger.warning( + "Custom orchestrator network isolation must be False when dependencies " + "are specified or using autocapture. To enable network isolation, " + "package all dependencies in the container or model artifacts " + "ahead of time." 
+ ) + built_orchestrator._enable_network_isolation = False + for model in built_ic_models: + model["Model"]._enable_network_isolation = False + deployables["CustomOrchestrator"] = { + "Name": self.inference_component_name, + "Mode": "InferenceComponent", + "ResourceRequirements": self.resource_requirements, + "Model": built_orchestrator, + } + + logger.info( + "=====================Custom orchestrator build complete.===================", + ) + + self._deployables = deployables + return self + + def _get_processing_unit(self): + """Detects if the resource requirements are intended for a CPU or GPU instance.""" + # Assume custom orchestrator will be deployed as an endpoint to a CPU instance + if not self.resource_requirements or not self.resource_requirements.num_accelerators: + return "cpu" + for ic in self.modelbuilder_list or []: + if ic.resource_requirements.num_accelerators > 0: + return "gpu" + if self.resource_requirements.num_accelerators > 0: + return "gpu" + + return "cpu" + + def _get_ic_resource_requirements(self, mb: ModelBuilder = None) -> ModelBuilder: + """Attempts fetching pre-benchmarked resource requirements for the MB from JumpStart.""" + if mb._is_jumpstart_model_id() and not mb.resource_requirements: + js_model = JumpStartModel(model_id=mb.model) + deployment_configs = js_model.list_deployment_configs() + if not deployment_configs: + raise ValueError( + "No resource requirements were provided for Inference Component " + f"{mb.inference_component_name} and no default deployment" + " configs were found in JumpStart." + ) + compute_requirements = ( + deployment_configs[0].get("DeploymentArgs").get("ComputeResourceRequirements") + ) + logger.info("Retrieved pre-benchmarked deployment configurations from JumpStart.") + mb.resource_requirements = ResourceRequirements( + requests={ + "memory": compute_requirements["MinMemoryRequiredInMb"], + "num_accelerators": compute_requirements.get( + "NumberOfAcceleratorDevicesRequired", None + ), + "copies": 1, + "num_cpus": compute_requirements.get("NumberOfCpuCoresRequired", None), + }, + limits={"memory": compute_requirements.get("MaxMemoryRequiredInMb", None)}, + ) + + return mb + + @_capture_telemetry("build_custom_orchestrator") + def _get_smd_image_uri(self, processing_unit: str = None) -> str: + """Gets the SMD Inference Image URI. + + Returns: + str: SMD Inference Image URI. + """ + from sagemaker import image_uris + import sys + + self.sagemaker_session = self.sagemaker_session or Session() + from packaging.version import Version + + formatted_py_version = f"py{sys.version_info.major}{sys.version_info.minor}" + if Version(f"{sys.version_info.major}.{sys.version_info.minor}") < Version("3.12"): + raise ValueError( + f"Found Python version {formatted_py_version} but " + f"Custom orchestrator deployment requires Python version >= 3.12." + ) + + INSTANCE_TYPES = {"cpu": "ml.c5.xlarge", "gpu": "ml.g5.4xlarge"} + + logger.info("Finding SMD inference image URI for a %s instance.", processing_unit) + + smd_uri = image_uris.retrieve( + framework="sagemaker-distribution", + image_scope="inference", + instance_type=INSTANCE_TYPES[processing_unit], + region=self.sagemaker_session.boto_region_name, + ) + logger.info("Found compatible image %s", smd_uri) + return smd_uri + # Model Builder is a class to build the model for deployment.
- # It supports two modes of deployment + # It supports three modes of deployment # 1/ SageMaker Endpoint # 2/ Local launch with container - def build( # pylint: disable=R0911 + # 3/ In process mode with Transformers server in beta release + @_capture_telemetry("ModelBuilder.build") + def _build_single_modelbuilder( # pylint: disable=R0911 self, mode: Type[Mode] = None, role_arn: str = None, @@ -700,13 +1226,39 @@ def build( # pylint: disable=R0911 Returns: Type[Model]: A deployable ``Model`` object. """ + self.modes = dict() if mode: self.mode = mode if role_arn: self.role_arn = role_arn - self.sagemaker_session = sagemaker_session or Session() + + self.serve_settings = self._get_serve_setting() + + if isinstance(self.model, TrainingJob): + self.model_path = self.model.model_artifacts.s3_model_artifacts + self.model = None + self._collect_training_job_model_telemetry() + elif isinstance(self.model, ModelTrainer): + self.model_path = self.model._latest_training_job.model_artifacts.s3_model_artifacts + self.model = None + self._collect_model_trainer_model_telemetry() + elif isinstance(self.model, Estimator): + self.model_path = self.model.output_path + self.model = None + self._collect_estimator_model_telemetry() + + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() + + self.sagemaker_session.settings._local_download_dir = self.model_path + + # DJL expects `HF_TOKEN` key. This allows backward compatibility + # until we deprecate HUGGING_FACE_HUB_TOKEN. + if self.env_vars.get("HUGGING_FACE_HUB_TOKEN") and not self.env_vars.get("HF_TOKEN"): + self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") + elif self.env_vars.get("HF_TOKEN") and not self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): + self.env_vars["HUGGING_FACE_HUB_TOKEN"] = self.env_vars.get("HF_TOKEN") self.sagemaker_session.settings._local_download_dir = self.model_path @@ -718,25 +1270,45 @@ def build( # pylint: disable=R0911 self.sagemaker_session.sagemaker_client._user_agent_creator.to_string ) - self.serve_settings = self._get_serve_setting() - self._is_custom_image_uri = self.image_uri is not None - self._is_mlflow_model = self._check_if_input_is_mlflow_model() - if self._is_mlflow_model: - logger.warning( - "Support of MLflow format models is experimental and is not intended" - " for production at this moment." - ) - self._initialize_for_mlflow() - _validate_input_for_mlflow(self.model_server) + + self._handle_mlflow_input() + + self._build_validations() + + if ( + not (isinstance(self.model, str) and self._is_jumpstart_model_id()) + ) and self.model_server: + self.built_model = self._build_for_model_server() + return self.built_model if isinstance(self.model, str): model_task = None + + if self._is_jumpstart_model_id(): + if self.mode == Mode.IN_PROCESS: + raise ValueError( + f"{self.mode} is not supported for Jumpstart models. " + "Please use LOCAL_CONTAINER mode to deploy a Jumpstart model" + " on your local machine." 
+ ) + self.model_hub = ModelHub.JUMPSTART + logger.debug("Building for Jumpstart model Id...") + self.built_model = self._build_for_jumpstart() + return self.built_model + + if self.mode != Mode.IN_PROCESS: + if self._use_jumpstart_equivalent(): + self.model_hub = ModelHub.JUMPSTART + logger.debug("Building for Jumpstart equivalent model Id...") + self.built_model = self._build_for_jumpstart() + return self.built_model + self.model_hub = ModelHub.HUGGINGFACE + + if self.model_metadata: model_task = self.model_metadata.get("HF_TASK") - if self._is_jumpstart_model_id(): - return self._build_for_jumpstart() - if self._is_djl(): # pylint: disable=R1705 + + if self._is_djl(): return self._build_for_djl() else: hf_model_md = get_huggingface_model_metadata( @@ -747,19 +1319,52 @@ def build( # pylint: disable=R0911 model_task = hf_model_md.get("pipeline_tag") if self.schema_builder is None and model_task is not None: self._hf_schema_builder_init(model_task) - if model_task == "text-generation": # pylint: disable=R1705 - return self._build_for_tgi() + if model_task == "text-generation": + self.built_model = self._build_for_tgi() + return self.built_model + if model_task in ["sentence-similarity", "feature-extraction"]: + self.built_model = self._build_for_tei() + return self.built_model elif self._can_fit_on_single_gpu(): - return self._build_for_transformers() - elif ( - self.model in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES - or self.model in defaults.FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES - ): - return self._build_for_djl() + self.built_model = self._build_for_transformers() + return self.built_model else: - return self._build_for_transformers() + self.built_model = self._build_for_transformers() + return self.built_model - self._build_validations() + # Set TorchServe as the default model server + if not self.model_server: + self.model_server = ModelServer.TORCHSERVE + self.built_model = self._build_for_torchserve() + return self.built_model + + raise ValueError("%s model server is not supported" % self.model_server) + + def _build_validations(self): + """Validations needed for model server overrides, auto-detection, or fallback""" + if self.inference_spec and self.model: + raise ValueError("Can only set one of the following: model, inference_spec.") + + if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None: + raise ValueError( + "Model_server must be set when a non-first-party image_uri is set. " + + "Supported model servers: %s" % supported_model_servers + ) + + def _build_for_model_server(self): # pylint: disable=R0911, R1710 + """Model server overrides""" + if self.model_server not in supported_model_servers: + raise ValueError( + "%s is not supported yet! 
Supported model servers: %s" + % (self.model_server, supported_model_servers) + ) + + mlflow_path = None + if self.model_metadata: + mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH) + + if not self.model and not mlflow_path and not self.inference_spec: + raise ValueError("Missing required parameter: `model`, an MLflow model path, or `inference_spec`") if self.model_server == ModelServer.TORCHSERVE: return self._build_for_torchserve() @@ -767,8 +1372,25 @@ def build( # pylint: disable=R0911 if self.model_server == ModelServer.TRITON: return self._build_for_triton() - raise ValueError("%s model server is not supported" % self.model_server) + if self.model_server == ModelServer.TENSORFLOW_SERVING: + return self._build_for_tensorflow_serving() + if self.model_server == ModelServer.DJL_SERVING: + return self._build_for_djl() + + if self.model_server == ModelServer.TEI: + return self._build_for_tei() + + if self.model_server == ModelServer.TGI: + return self._build_for_tgi() + + if self.model_server == ModelServer.MMS: + return self._build_for_transformers() + + if self.model_server == ModelServer.SMD: + return self._build_for_smd() + + @_capture_telemetry("ModelBuilder.save") def save( self, save_path: Optional[str] = None, @@ -781,8 +1403,15 @@ def save( This function is available for models served by DJL serving. Args: - save_path (Optional[str]): The path where you want to save resources. - s3_path (Optional[str]): The path where you want to upload resources. + save_path (Optional[str]): The path where you want to save resources. Defaults to + ``None``. + s3_path (Optional[str]): The path where you want to upload resources. Defaults to + ``None``. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. Defaults to + ``None``. + role_arn (Optional[str]): The IAM role arn. Defaults to ``None``. """ self.sagemaker_session = sagemaker_session or Session() @@ -813,6 +1442,19 @@ def validate(self, model_dir: str) -> Type[bool]: return get_metadata(model_dir) + def set_tracking_arn(self, arn: str): + """Set the tracking server ARN""" + # TODO: support native MLflow URIs + if importlib.util.find_spec("sagemaker_mlflow"): + import mlflow + + mlflow.set_tracking_uri(arn) + self.model_metadata[MLFLOW_TRACKING_ARN] = arn + else: + raise ImportError( + "Unable to import sagemaker_mlflow; check that sagemaker_mlflow is installed" + ) + def _hf_schema_builder_init(self, model_task: str): """Initialize the schema builder for the given HF_TASK @@ -892,3 +1534,690 @@ def _try_fetch_gpu_info(self): raise ValueError( f"Unable to determine single GPU size for instance: [{self.instance_type}]" ) + + def optimize( + self, + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + role_arn: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + accept_eula: Optional[bool] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = 36000, + sagemaker_session: Optional[Session] = None, + ) -> Model: + """Create an optimized deployable ``Model`` instance with ``ModelBuilder``. 
+ + Args: + output_path (str): Specifies where to store the compiled/quantized model. + instance_type (str): Target deployment instance type that the model is optimized for. + role_arn (Optional[str]): Execution role arn. Defaults to ``None``. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + 36000 seconds. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Model: A deployable ``Model`` object. + """ + + # need to get telemetry_opt_out info before telemetry decorator is called + self.serve_settings = self._get_serve_setting() + + return self._model_builder_optimize_wrapper( + output_path=output_path, + instance_type=instance_type, + role_arn=role_arn, + tags=tags, + job_name=job_name, + accept_eula=accept_eula, + quantization_config=quantization_config, + compilation_config=compilation_config, + speculative_decoding_config=speculative_decoding_config, + sharding_config=sharding_config, + env_vars=env_vars, + vpc_config=vpc_config, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + sagemaker_session=sagemaker_session, + ) + + @_capture_telemetry("optimize") + def _model_builder_optimize_wrapper( + self, + output_path: Optional[str] = None, + instance_type: Optional[str] = None, + role_arn: Optional[str] = None, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + accept_eula: Optional[bool] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = 36000, + sagemaker_session: Optional[Session] = None, + ) -> Model: + """Runs a model optimization job. + + Args: + output_path (str): Specifies where to store the compiled/quantized model. + instance_type (str): Target deployment instance type that the model is optimized for. + role_arn (Optional[str]): Execution role arn. Defaults to ``None``. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. 
+ job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None``. + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None``. + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + 36000 seconds. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Model: A deployable ``Model`` object. + """ + if ( + hasattr(self, "enable_network_isolation") + and self.enable_network_isolation + and sharding_config + ): + raise ValueError( + "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " + "Loading requires network access." + ) + + # TODO: ideally these dictionaries need to be sagemaker_core shapes + # TODO: for organization, abstract all validation behind this fn + _validate_optimization_configuration( + is_jumpstart=self._is_jumpstart_model_id(), + instance_type=instance_type, + quantization_config=quantization_config, + compilation_config=compilation_config, + sharding_config=sharding_config, + speculative_decoding_config=speculative_decoding_config, + ) + + self.is_compiled = compilation_config is not None + self.is_quantized = quantization_config is not None + self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider( + speculative_decoding_config + ) + + if self.mode != Mode.SAGEMAKER_ENDPOINT: + raise ValueError("Model optimization is only supported in SageMaker Endpoint mode.") + + if sharding_config and ( + quantization_config or compilation_config or speculative_decoding_config + ): + raise ValueError( + ( + "Sharding config is mutually exclusive " + "and cannot be combined with any other optimization." + ) + ) + + if sharding_config: + has_tensor_parallel_degree_in_env_vars = ( + env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars + ) + has_tensor_parallel_degree_in_overrides = ( + sharding_config + and sharding_config.get("OverrideEnvironment") + and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config.get("OverrideEnvironment") + ) + if ( + not has_tensor_parallel_degree_in_env_vars + and not has_tensor_parallel_degree_in_overrides + ): + raise ValueError( + ( + "OPTION_TENSOR_PARALLEL_DEGREE is a required " + "environment variable with sharding config."
+ ) + ) + + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() + self.instance_type = instance_type or self.instance_type + self.role_arn = role_arn or self.role_arn + + job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}" + if self._is_jumpstart_model_id(): + self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) + if self.pysdk_model: + self.pysdk_model.set_deployment_config( + instance_type=instance_type, config_name="lmi" + ) + input_args = self._optimize_for_jumpstart( + output_path=output_path, + instance_type=instance_type, + tags=tags, + job_name=job_name, + accept_eula=accept_eula, + quantization_config=quantization_config, + compilation_config=compilation_config, + speculative_decoding_config=speculative_decoding_config, + sharding_config=sharding_config, + env_vars=env_vars, + vpc_config=vpc_config, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + ) + else: + if self.model_server != ModelServer.DJL_SERVING: + logger.info("Overriding model server to DJL_SERVING.") + self.model_server = ModelServer.DJL_SERVING + + self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) + input_args = self._optimize_for_hf( + output_path=output_path, + tags=tags, + job_name=job_name, + quantization_config=quantization_config, + compilation_config=compilation_config, + speculative_decoding_config=speculative_decoding_config, + sharding_config=sharding_config, + env_vars=env_vars, + vpc_config=vpc_config, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + ) + + if sharding_config: + self.pysdk_model._is_sharded_model = True + + if input_args: + optimization_instance_type = input_args["DeploymentInstanceType"] + + # Compilation using TRTLLM and Llama-3.1 is currently not supported. + # TRTLLM is used by Neo if the following are provided: + # 1) a GPU instance type + # 2) compilation config + gpu_instance_families = ["g5", "g6", "p4d", "p4de", "p5"] + is_gpu_instance = optimization_instance_type and any( + gpu_instance_family in optimization_instance_type + for gpu_instance_family in gpu_instance_families + ) + + # HF Model ID format = "meta-llama/Meta-Llama-3.1-8B" + # JS Model ID format = "meta-textgeneration-llama-3-1-8b" + is_llama_3_plus = self.model and bool( + re.search(r"llama-3[\.\-][1-9]\d*", self.model.lower()) + ) + + if is_gpu_instance and self.model and self.is_compiled: + if is_llama_3_plus: + raise ValueError( + "Compilation is not supported for models greater " + "than Llama-3.0 with a GPU instance." + ) + if speculative_decoding_config: + raise ValueError( + "Compilation is not supported with speculative decoding with " + "a GPU instance." 
+ ) + + self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args) + job_status = self.sagemaker_session.wait_for_optimization_job(job_name) + return _generate_optimized_model(self.pysdk_model, job_status) + + self.pysdk_model.remove_tag_with_key(Tag.OPTIMIZATION_JOB_NAME) + if not speculative_decoding_config: + self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER) + + return self.pysdk_model + + def _optimize_for_hf( + self, + output_path: str, + tags: Optional[Tags] = None, + job_name: Optional[str] = None, + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, + env_vars: Optional[Dict] = None, + vpc_config: Optional[Dict] = None, + kms_key: Optional[str] = None, + max_runtime_in_sec: Optional[int] = None, + ) -> Optional[Dict[str, Any]]: + """Runs a model optimization job. + + Args: + output_path (str): Specifies where to store the compiled/quantized model. + tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``. + job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``. + quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``. + compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. + speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. + Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` + env_vars (Optional[Dict]): Additional environment variables to run the optimization + container. Defaults to ``None``. + vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. + kms_key (Optional[str]): KMS key ARN used to encrypt the model artifacts when uploading + to S3. Defaults to ``None``. + max_runtime_in_sec (Optional[int]): Maximum job execution time in seconds. Defaults to + ``None``. + + Returns: + Optional[Dict[str, Any]]: Model optimization job input arguments. 
+ """ + if speculative_decoding_config: + if speculative_decoding_config.get("ModelProvider", "").lower() == "jumpstart": + _jumpstart_speculative_decoding( + model=self.pysdk_model, + speculative_decoding_config=speculative_decoding_config, + sagemaker_session=self.sagemaker_session, + ) + else: + self.pysdk_model = _custom_speculative_decoding( + self.pysdk_model, speculative_decoding_config, False + ) + + if quantization_config or compilation_config or sharding_config: + create_optimization_job_args = { + "OptimizationJobName": job_name, + "DeploymentInstanceType": self.instance_type, + "RoleArn": self.role_arn, + } + + if env_vars: + self.pysdk_model.env.update(env_vars) + create_optimization_job_args["OptimizationEnvironment"] = env_vars + + self._optimize_prepare_for_hf() + model_source = _generate_model_source(self.pysdk_model.model_data, False) + create_optimization_job_args["ModelSource"] = model_source + + ( + optimization_config, + quantization_override_env, + compilation_override_env, + sharding_override_env, + ) = _extract_optimization_config_and_env( + quantization_config, compilation_config, sharding_config + ) + create_optimization_job_args["OptimizationConfigs"] = [ + {k: v} for k, v in optimization_config.items() + ] + self.pysdk_model.env.update( + { + **(quantization_override_env or {}), + **(compilation_override_env or {}), + **(sharding_override_env or {}), + } + ) + + output_config = {"S3OutputLocation": output_path} + if kms_key: + output_config["KmsKeyId"] = kms_key + create_optimization_job_args["OutputConfig"] = output_config + + if max_runtime_in_sec: + create_optimization_job_args["StoppingCondition"] = { + "MaxRuntimeInSeconds": max_runtime_in_sec + } + if tags: + create_optimization_job_args["Tags"] = tags + if vpc_config: + create_optimization_job_args["VpcConfig"] = vpc_config + + # HF_MODEL_ID needs not to be present, otherwise, + # HF model artifacts will be re-downloaded during deployment + if "HF_MODEL_ID" in self.pysdk_model.env: + del self.pysdk_model.env["HF_MODEL_ID"] + + return create_optimization_job_args + return None + + def _optimize_prepare_for_hf(self): + """Prepare huggingface model data for optimization.""" + custom_model_path: str = ( + self.model_metadata.get("CUSTOM_MODEL_PATH") if self.model_metadata else None + ) + if _is_s3_uri(custom_model_path): + # Remove slash by the end of s3 uri, as it may lead to / subfolder during upload. 
+ custom_model_path = ( + custom_model_path[:-1] if custom_model_path.endswith("/") else custom_model_path + ) + else: + if not custom_model_path: + custom_model_path = f"/tmp/sagemaker/model-builder/{self.model}" + download_huggingface_model_metadata( + self.model, + os.path.join(custom_model_path, "code"), + self.env_vars.get("HUGGING_FACE_HUB_TOKEN"), + ) + + self.pysdk_model.model_data, env = self._prepare_for_mode( + model_path=custom_model_path, + should_upload_artifacts=True, + ) + self.pysdk_model.env.update(env) + + @_capture_telemetry("ModelBuilder.deploy") + def deploy( + self, + endpoint_name: str = None, + container_timeout_in_second: int = 300, + instance_type: str = None, + initial_instance_count: Optional[int] = 1, + inference_config: Optional[ + Union[ + ServerlessInferenceConfig, + AsyncInferenceConfig, + BatchTransformInferenceConfig, + ResourceRequirements, + ] + ] = None, + update_endpoint: Optional[bool] = False, + custom_orchestrator_instance_type: str = None, + custom_orchestrator_initial_instance_count: int = None, + **kwargs, + ) -> Union[Predictor, Transformer, List[Predictor]]: + """Deploys the built Model. + + Depending on the type of config provided, this function will call deployment accordingly. + Args: + endpoint_name (str): Name of the endpoint to deploy. + The supplied base name is used as a prefix and + a unique ID is appended to guarantee uniqueness. + initial_instance_count (int): Number of instances to deploy. + inference_config (Optional[Union[ServerlessInferenceConfig, + AsyncInferenceConfig, BatchTransformInferenceConfig, ResourceRequirements]]) : + Additional Config for different deployment types such as + serverless, async, batch and multi-model/container + update_endpoint (Optional[bool]): + Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint + and delete resources corresponding to the previous EndpointConfig. 
Default: False + Note: Currently this is only supported for single model endpoints + Returns: + Transformer for Batch Deployments + Predictors for all others + """ + if not hasattr(self, "built_model") and not hasattr(self, "_deployables"): + raise ValueError("Model needs to be built before deploying") + + if not hasattr(self, "_deployables"): + if not inference_config: # Real-time Deployment + return self.built_model.deploy( + instance_type=self.instance_type, + initial_instance_count=initial_instance_count, + endpoint_name=endpoint_name, + update_endpoint=update_endpoint, + ) + + if isinstance(inference_config, ServerlessInferenceConfig): + return self.built_model.deploy( + serverless_inference_config=inference_config, + endpoint_name=endpoint_name, + update_endpoint=update_endpoint, + ) + + if isinstance(inference_config, AsyncInferenceConfig): + return self.built_model.deploy( + instance_type=self.instance_type, + initial_instance_count=initial_instance_count, + async_inference_config=inference_config, + endpoint_name=endpoint_name, + update_endpoint=update_endpoint, + ) + + if isinstance(inference_config, BatchTransformInferenceConfig): + transformer = self.built_model.transformer( + instance_type=inference_config.instance_type, + output_path=inference_config.output_path, + instance_count=inference_config.instance_count, + ) + return transformer + + if isinstance(inference_config, ResourceRequirements): + if update_endpoint: + raise ValueError( + "update_endpoint is currently only supported for single model endpoints" + ) + # Multi Model and MultiContainer endpoints with Inference Component + return self.built_model.deploy( + instance_type=self.instance_type, + mode=Mode.SAGEMAKER_ENDPOINT, + endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, + resources=inference_config, + initial_instance_count=initial_instance_count, + role=self.role_arn, + update_endpoint=update_endpoint, + ) + + raise ValueError("Deployment Options not supported") + + # Iterate through deployables for a custom orchestrator deployment. + # Create all Inference Components first before deploying custom orchestrator if present. + predictors = [] + for inference_component in self._deployables.get("InferenceComponents", []): + predictors.append( + self._deploy_for_ic( + ic_data=inference_component, + container_timeout_in_seconds=container_timeout_in_second, + instance_type=instance_type, + initial_instance_count=initial_instance_count, + endpoint_name=endpoint_name, + **kwargs, + ) + ) + if self._deployables.get("CustomOrchestrator", None): + custom_orchestrator = self._deployables.get("CustomOrchestrator") + if not custom_orchestrator_instance_type and not instance_type: + logger.warning( + "Deploying custom orchestrator as an endpoint but no instance type was " + "set. Defaulting to `ml.c5.xlarge`."
+ ) + custom_orchestrator_instance_type = "ml.c5.xlarge" + custom_orchestrator_initial_instance_count = 1 + if custom_orchestrator["Mode"] == "Endpoint": + logger.info( + "Deploying custom orchestrator on instance type %s.", + custom_orchestrator_instance_type, + ) + predictors.append( + custom_orchestrator["Model"].deploy( + instance_type=custom_orchestrator_instance_type, + initial_instance_count=custom_orchestrator_initial_instance_count, + **kwargs, + ) + ) + elif custom_orchestrator["Mode"] == "InferenceComponent": + logger.info( + "Deploying custom orchestrator as an inference component " + f"to endpoint {endpoint_name}" + ) + predictors.append( + self._deploy_for_ic( + ic_data=custom_orchestrator, + container_timeout_in_seconds=container_timeout_in_second, + instance_type=custom_orchestrator_instance_type or instance_type, + initial_instance_count=custom_orchestrator_initial_instance_count + or initial_instance_count, + endpoint_name=endpoint_name, + **kwargs, + ) + ) + + return predictors + + def display_benchmark_metrics(self, **kwargs): + """Display Markdown Benchmark Metrics for deployment configs.""" + if not isinstance(self.model, str): + raise ValueError("Benchmarking is only supported for JumpStart or HuggingFace models") + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): + return super().display_benchmark_metrics(**kwargs) + else: + raise ValueError("This model does not have benchmark metrics yet") + + def get_deployment_config(self) -> Optional[Dict[str, Any]]: + """Gets the deployment config to apply to the model. + + Returns: + Optional[Dict[str, Any]]: Deployment config to apply to this model. + """ + if not isinstance(self.model, str): + raise ValueError( + "Deployment config is only supported for JumpStart or HuggingFace models" + ) + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): + return super().get_deployment_config() + else: + raise ValueError("This model does not have any deployment config yet") + + def list_deployment_configs(self) -> List[Dict[str, Any]]: + """List deployment configs for the model in the current region. + + Returns: + List[Dict[str, Any]]: A list of deployment configs. + """ + if not isinstance(self.model, str): + raise ValueError( + "Deployment config is only supported for JumpStart or HuggingFace models" + ) + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): + return super().list_deployment_configs() + else: + raise ValueError("This model does not have any deployment config yet") + + def set_deployment_config(self, config_name: str, instance_type: str) -> None: + """Sets the deployment config to apply to the model. + + Args: + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. + """ + if not isinstance(self.model, str): + raise ValueError( + "Deployment config is only supported for JumpStart or HuggingFace models" + ) + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): + logger.warning( + "If there are existing deployment configurations, " + "they will be overwritten by the config %s", + config_name, + ) + return super().set_deployment_config(config_name, instance_type) + else: + raise ValueError(f"The deployment config {config_name} cannot be set on this model") + + def _use_jumpstart_equivalent(self): + """Check if the HuggingFace model has a JumpStart equivalent. 
+ + Replace it with the equivalent if there's one + """ + # Do not use the equivalent JS model if image_uri or env_vars is provided + if self.image_uri or self.env_vars: + return False + if not hasattr(self, "_has_jumpstart_equivalent"): + self._jumpstart_mapping = self._retrieve_hugging_face_model_mapping() + self._has_jumpstart_equivalent = self.model in self._jumpstart_mapping + if self._has_jumpstart_equivalent: + # Use schema builder from HF model metadata + if not self.schema_builder: + model_task = None + if self.model_metadata: + model_task = self.model_metadata.get("HF_TASK") + hf_model_md = get_huggingface_model_metadata(self.model) + if not model_task: + model_task = hf_model_md.get("pipeline_tag") + if model_task: + self._hf_schema_builder_init(model_task) + + huggingface_model_id = self.model + jumpstart_model_id = self._jumpstart_mapping[huggingface_model_id]["jumpstart-model-id"] + self.model = jumpstart_model_id + merged_date = self._jumpstart_mapping[huggingface_model_id].get("merged-at") + self._build_for_jumpstart() + compare_model_diff_message = ( + "If you want to identify the differences between the two, " + "please use model_uris.retrieve() to retrieve the model " + "artifact S3 URI and compare them." + ) + logger.warning( # pylint: disable=logging-fstring-interpolation + "Please note that for this model we are using the JumpStart's " + f'local copy "{jumpstart_model_id}" ' + f'of the HuggingFace model "{huggingface_model_id}" you chose. ' + "We strive to keep our local copy synced with the HF model hub closely. " + "This model was synced " + f"{f'on {merged_date}' if merged_date else 'before 11/04/2024'}. " + f"{compare_model_diff_message if not self._is_gated_model() else ''}" + ) + return True + return False + + def _retrieve_hugging_face_model_mapping(self): + """Retrieve the HuggingFace/JumpStart model mapping and preprocess it.""" + converted_mapping = {} + region = self.sagemaker_session.boto_region_name + try: + mapping_json_object = JumpStartS3PayloadAccessor.get_object_cached( + bucket=get_jumpstart_content_bucket(region), + key="hf_model_id_map_cache.json", + region=region, + s3_client=self.sagemaker_session.s3_client, + ) + mapping = json.loads(mapping_json_object) + except Exception: # pylint: disable=broad-except + return converted_mapping + + for k, v in mapping.items(): + converted_mapping[v["hf-model-id"]] = { + "jumpstart-model-id": k, + "jumpstart-model-version": v["jumpstart-model-version"], + "merged-at": v.get("merged-at"), + "hf-model-repo-sha": v.get("hf-model-repo-sha"), + } + return converted_mapping diff --git a/src/sagemaker/serve/builder/requirements_manager.py b/src/sagemaker/serve/builder/requirements_manager.py new file mode 100644 index 0000000000..a8b41dba40 --- /dev/null +++ b/src/sagemaker/serve/builder/requirements_manager.py @@ -0,0 +1,100 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
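# --- Editor's illustrative sketch (not part of the patch) -------------------
# How the mapping inversion in _retrieve_hugging_face_model_mapping above
# behaves on a sample payload. The model ids below are made up for
# illustration; the real cache is the hf_model_id_map_cache.json object in
# the JumpStart content bucket.
import json

raw = json.loads(
    '{"meta-textgeneration-example-7b": {'
    '"hf-model-id": "example-org/example-7b", '
    '"jumpstart-model-version": "1.0.0", '
    '"merged-at": "2024-11-04"}}'
)
converted = {
    v["hf-model-id"]: {
        "jumpstart-model-id": k,
        "jumpstart-model-version": v["jumpstart-model-version"],
        "merged-at": v.get("merged-at"),
        "hf-model-repo-sha": v.get("hf-model-repo-sha"),
    }
    for k, v in raw.items()
}
assert converted["example-org/example-7b"]["jumpstart-model-id"] == (
    "meta-textgeneration-example-7b"
)
# ---------------------------------------------------------------------------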
+"""Requirements Manager class to pull in client dependencies from a .txt or .yml file""" +from __future__ import absolute_import +import logging +import os +import subprocess + +from typing import Optional + +logger = logging.getLogger(__name__) + + +class RequirementsManager: + """Manages dependency installation by detecting file types""" + + def capture_and_install_dependencies(self, dependencies: Optional[str] = None) -> str: + """Detects the type of file dependencies will be installed from + + If a req.txt or conda.yml file is provided, it verifies their existence and + returns the local file path + + Args: + dependencies (str): Local path where dependencies file exists. + + Returns: + file path of the existing or generated dependencies file + """ + _dependencies = dependencies or self._detect_conda_env_and_local_dependencies + + # Dependencies specified as either req.txt or conda_env.yml + if _dependencies.endswith(".txt"): + self._install_requirements_txt() + elif _dependencies.endswith(".yml"): + self._update_conda_env_in_path() + else: + raise ValueError(f'Invalid dependencies provided: "{_dependencies}"') + + def _install_requirements_txt(self): + """Install requirements.txt file using pip""" + logger.info("Running command to pip install") + subprocess.run("pip install -r in_process_requirements.txt", shell=True, check=True) + logger.info("Command ran successfully") + + def _update_conda_env_in_path(self): + """Update conda env using conda yml file""" + logger.info("Updating conda env") + subprocess.run("conda env update -f conda_in_process.yml", shell=True, check=True) + logger.info("Conda env updated successfully") + + def _get_active_conda_env_name(self) -> str: + """Returns the conda environment name from the set environment variable. None otherwise.""" + return os.getenv("CONDA_DEFAULT_ENV") + + def _get_active_conda_env_prefix(self) -> str: + """Returns the conda prefix from the set environment variable. None otherwise.""" + return os.getenv("CONDA_PREFIX") + + def _detect_conda_env_and_local_dependencies(self) -> str: + """Generates dependencies list from the user's local runtime. + + Raises RuntimeEnvironmentError if not able to. + + Currently supports: conda environments + """ + + # Try to capture dependencies from the conda environment, if any. 
+ conda_env_name = self._get_active_conda_env_name() + logger.info("Found conda_env_name: '%s'", conda_env_name) + conda_env_prefix = None + + if conda_env_name is None: + conda_env_prefix = self._get_active_conda_env_prefix() + + if conda_env_name is None and conda_env_prefix is None: + local_dependencies_path = os.path.join(os.getcwd(), "in_process_requirements.txt") + logger.info(local_dependencies_path) + + return local_dependencies_path + + if conda_env_name == "base": + logger.warning( + "We recommend using an environment other than base to " + "isolate your project dependencies from conda dependencies" + ) + + local_dependencies_path = os.path.join(os.getcwd(), "conda_in_process.yml") + logger.info(local_dependencies_path) + + return local_dependencies_path diff --git a/src/sagemaker/serve/builder/schema_builder.py b/src/sagemaker/serve/builder/schema_builder.py index 3fd1816d0e..7f70e98747 100644 --- a/src/sagemaker/serve/builder/schema_builder.py +++ b/src/sagemaker/serve/builder/schema_builder.py @@ -4,6 +4,7 @@ import io import logging from pathlib import Path +from typing import Callable import numpy as np from pandas import DataFrame @@ -286,7 +287,7 @@ def _is_path_to_file(data: object) -> bool: def _validate_translations( - payload: object, serialize_callable: callable, deserialize_callable: callable + payload: object, serialize_callable: Callable, deserialize_callable: Callable ) -> None: """Placeholder docstring""" try: diff --git a/src/sagemaker/serve/builder/tei_builder.py b/src/sagemaker/serve/builder/tei_builder.py new file mode 100644 index 0000000000..72ecef9448 --- /dev/null +++ b/src/sagemaker/serve/builder/tei_builder.py @@ -0,0 +1,253 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
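# --- Editor's illustrative sketch (not part of the patch) -------------------
# Expected behavior of the RequirementsManager defined above, under stated
# assumptions: with no active conda environment, detection falls back to
# <cwd>/in_process_requirements.txt; inside a conda env it resolves to
# <cwd>/conda_in_process.yml instead. The printed path depends on the
# caller's working directory.
from sagemaker.serve.builder.requirements_manager import RequirementsManager

manager = RequirementsManager()
detected = manager._detect_conda_env_and_local_dependencies()
print(detected)  # e.g. /home/user/project/in_process_requirements.txt

# Installing from an explicit requirements file (runs `pip install -r ...`):
# manager.capture_and_install_dependencies("in_process_requirements.txt")
# ---------------------------------------------------------------------------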
+"""Holds mixin logic to support deployment of Model ID""" +from __future__ import absolute_import +import logging +from typing import Type +from abc import ABC, abstractmethod + +from sagemaker import image_uris +from sagemaker.model import Model +from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf + +from sagemaker.huggingface import HuggingFaceModel +from sagemaker.serve.utils.local_hardware import ( + _get_nb_instance, +) +from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure +from sagemaker.serve.utils.optimize_utils import _is_optimized +from sagemaker.serve.utils.predictors import InProcessModePredictor, TeiLocalModePredictor +from sagemaker.serve.utils.types import ModelServer +from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.serve.utils.telemetry_logger import _capture_telemetry +from sagemaker.base_predictor import PredictorBase + +logger = logging.getLogger(__name__) +LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS] + +_CODE_FOLDER = "code" + + +class TEI(ABC): + """TEI build logic for ModelBuilder()""" + + def __init__(self): + self.model = None + self.serve_settings = None + self.sagemaker_session = None + self.model_path = None + self.dependencies = None + self.modes = None + self.mode = None + self.model_server = None + self.image_uri = None + self._is_custom_image_uri = False + self.image_config = None + self.vpc_config = None + self._original_deploy = None + self.hf_model_config = None + self._default_tensor_parallel_degree = None + self._default_data_type = None + self._default_max_tokens = None + self.pysdk_model = None + self.schema_builder = None + self.env_vars = None + self.nb_instance_type = None + self.ram_usage_model_load = None + self.secret_key = None + self.role_arn = None + self.name = None + + @abstractmethod + def _prepare_for_mode(self, *args, **kwargs): + """Placeholder docstring""" + + @abstractmethod + def _get_client_translators(self): + """Placeholder docstring""" + + def _set_to_tei(self): + """Placeholder docstring""" + if self.model_server != ModelServer.TEI: + messaging = ( + "HuggingFace Model ID support on model server: " + f"{self.model_server} is not currently supported. " + f"Defaulting to {ModelServer.TEI}" + ) + logger.warning(messaging) + self.model_server = ModelServer.TEI + + def _create_tei_model(self, **kwargs) -> Type[Model]: + """Placeholder docstring""" + if self.nb_instance_type and "instance_type" not in kwargs: + kwargs.update({"instance_type": self.nb_instance_type}) + + if not self.image_uri: + self.image_uri = image_uris.retrieve( + "huggingface-tei", + image_scope="inference", + instance_type=kwargs.get("instance_type"), + region=self.sagemaker_session.boto_region_name, + ) + + pysdk_model = HuggingFaceModel( + image_uri=self.image_uri, + image_config=self.image_config, + vpc_config=self.vpc_config, + env=self.env_vars, + role=self.role_arn, + sagemaker_session=self.sagemaker_session, + name=self.name, + ) + + logger.info("Detected %s. 
Proceeding with the deployment.", self.image_uri) + + self._original_deploy = pysdk_model.deploy + pysdk_model.deploy = self._tei_model_builder_deploy_wrapper + return pysdk_model + + @_capture_telemetry("tei.deploy") + def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: + """Placeholder docstring""" + timeout = kwargs.get("model_data_download_timeout") + if timeout: + self.pysdk_model.env.update({"MODEL_LOADING_TIMEOUT": str(timeout)}) + + if "mode" in kwargs and kwargs.get("mode") != self.mode: + overwrite_mode = kwargs.get("mode") + # mode overwritten by customer during model.deploy() + logger.warning( + "Deploying in %s Mode, overriding existing configurations set for %s mode", + overwrite_mode, + self.mode, + ) + + if overwrite_mode == Mode.SAGEMAKER_ENDPOINT: + self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT + elif overwrite_mode == Mode.LOCAL_CONTAINER: + self._prepare_for_mode() + self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER + else: + raise ValueError("Mode %s is not supported!" % overwrite_mode) + + serializer = self.schema_builder.input_serializer + deserializer = self.schema_builder._output_deserializer + if self.mode == Mode.IN_PROCESS: + self._prepare_for_mode() + predictor = InProcessModePredictor( + self.modes[str(Mode.IN_PROCESS)], serializer, deserializer + ) + + self.modes[str(Mode.IN_PROCESS)].create_server( + predictor, + ) + return predictor + + if self.mode == Mode.LOCAL_CONTAINER: + timeout = kwargs.get("model_data_download_timeout") + + predictor = TeiLocalModePredictor( + self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer + ) + + self.modes[str(Mode.LOCAL_CONTAINER)].create_server( + self.image_uri, + timeout if timeout else 1800, + None, + predictor, + self.pysdk_model.env, + jumpstart=False, + ) + + return predictor + + if "mode" in kwargs: + del kwargs["mode"] + if "role" in kwargs: + self.pysdk_model.role = kwargs.get("role") + del kwargs["role"] + + if not _is_optimized(self.pysdk_model): + env_vars = {} + if str(Mode.LOCAL_CONTAINER) in self.modes: + # upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT + self.pysdk_model.model_data, env_vars = self._prepare_for_mode( + model_path=self.model_path, should_upload_artifacts=True + ) + else: + _, env_vars = self._prepare_for_mode() + + self.env_vars.update(env_vars) + self.pysdk_model.env.update(self.env_vars) + + # if the weights have been cached via local container mode -> set to offline + if str(Mode.LOCAL_CONTAINER) in self.modes: + self.pysdk_model.env.update({"HF_HUB_OFFLINE": "1"}) + else: + # if it has not been built for local container we must use a cache + # that hosting has write access to. + self.pysdk_model.env["HF_HOME"] = "/tmp" + self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp" + + if "endpoint_logging" not in kwargs: + kwargs["endpoint_logging"] = True + + if self.nb_instance_type and "instance_type" not in kwargs: + kwargs.update({"instance_type": self.nb_instance_type}) + elif not self.nb_instance_type and "instance_type" not in kwargs: + raise ValueError( + "Instance type must be provided when deploying " "to SageMaker Endpoint mode."
+ ) + + if "initial_instance_count" not in kwargs: + kwargs.update({"initial_instance_count": 1}) + + predictor = self._original_deploy(*args, **kwargs) + + if "HF_HUB_OFFLINE" in self.pysdk_model.env: + self.pysdk_model.env.update({"HF_HUB_OFFLINE": "0"}) + + predictor.serializer = serializer + predictor.deserializer = deserializer + return predictor + + def _build_for_hf_tei(self): + """Placeholder docstring""" + self.nb_instance_type = _get_nb_instance() + + _create_dir_structure(self.model_path) + if not hasattr(self, "pysdk_model"): + self.env_vars.update({"HF_MODEL_ID": self.model}) + self.hf_model_config = _get_model_config_properties_from_hf( + self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") + ) + + self.pysdk_model = self._create_tei_model() + + if self.mode in LOCAL_MODES: + self._prepare_for_mode() + + return self.pysdk_model + + def _build_for_tei(self): + """Placeholder docstring""" + self.secret_key = None + + self._set_to_tei() + + self.pysdk_model = self._build_for_hf_tei() + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session + return self.pysdk_model diff --git a/src/sagemaker/serve/builder/tf_serving_builder.py b/src/sagemaker/serve/builder/tf_serving_builder.py new file mode 100644 index 0000000000..044e0460bc --- /dev/null +++ b/src/sagemaker/serve/builder/tf_serving_builder.py @@ -0,0 +1,135 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
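# --- Editor's illustrative sketch (not part of the patch) -------------------
# End-to-end shape of the TEI path above: a sentence-similarity or
# feature-extraction model id routes the build to _build_for_tei, which wraps
# the huggingface-tei image in a HuggingFaceModel. The model id, role ARN,
# and instance type here are placeholders, and the exact builder arguments
# may vary by SDK release.
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder

model_builder = ModelBuilder(
    model="sentence-transformers/all-MiniLM-L6-v2",  # assumed HF model id
    schema_builder=SchemaBuilder(
        sample_input="What is SageMaker?",
        sample_output=[[0.1, 0.2, 0.3]],
    ),
    role_arn="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
    instance_type="ml.g5.xlarge",
)
built = model_builder.build()  # dispatches to _build_for_tei()
predictor = built.deploy(initial_instance_count=1)
# ---------------------------------------------------------------------------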
+"""Holds mixin logic to support deployment of Model ID""" +from __future__ import absolute_import +import logging +import os +from pathlib import Path +from abc import ABC, abstractmethod + +from sagemaker import Session +from sagemaker.serve.detector.pickler import save_pkl +from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving +from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor + +logger = logging.getLogger(__name__) + +_TF_SERVING_MODEL_BUILDER_ENTRY_POINT = "inference.py" +_CODE_FOLDER = "code" + + +# pylint: disable=attribute-defined-outside-init, disable=E1101 +class TensorflowServing(ABC): + """TensorflowServing build logic for ModelBuilder()""" + + def __init__(self): + self.model = None + self.serve_settings = None + self.sagemaker_session = None + self.model_path = None + self.dependencies = None + self.modes = None + self.mode = None + self.model_server = None + self.image_uri = None + self._is_custom_image_uri = False + self.image_config = None + self.vpc_config = None + self._original_deploy = None + self.secret_key = None + self.engine = None + self.pysdk_model = None + self.schema_builder = None + self.env_vars = None + self.name = None + + @abstractmethod + def _prepare_for_mode(self): + """Prepare model artifacts based on mode.""" + + @abstractmethod + def _get_client_translators(self): + """Set up client marshaller based on schema builder.""" + + def _save_schema_builder(self): + """Save schema builder for tensorflow serving.""" + if not os.path.exists(self.model_path): + os.makedirs(self.model_path) + + code_path = Path(self.model_path).joinpath("code") + save_pkl(code_path, self.schema_builder) + + def _get_tensorflow_predictor( + self, endpoint_name: str, sagemaker_session: Session + ) -> TensorFlowPredictor: + """Creates a TensorFlowPredictor object""" + serializer, deserializer = self._get_client_translators() + + return TensorFlowPredictor( + endpoint_name=endpoint_name, + sagemaker_session=sagemaker_session, + serializer=serializer, + deserializer=deserializer, + ) + + def _validate_for_tensorflow_serving(self): + """Validate for tensorflow serving""" + if not getattr(self, "_is_mlflow_model", False): + raise ValueError("Tensorflow Serving is currently only supported for mlflow models.") + + def _create_tensorflow_model(self): + """Creates a TensorFlow model object""" + self.pysdk_model = TensorFlowModel( + image_uri=self.image_uri, + image_config=self.image_config, + vpc_config=self.vpc_config, + model_data=self.s3_upload_path, + role=self.serve_settings.role_arn, + env=self.env_vars, + sagemaker_session=self.sagemaker_session, + predictor_cls=self._get_tensorflow_predictor, + name=self.name, + ) + + self.pysdk_model.mode = self.mode + self.pysdk_model.modes = self.modes + self.pysdk_model.serve_settings = self.serve_settings + if hasattr(self, "role_arn") and self.role_arn: + self.pysdk_model.role = self.role_arn + if hasattr(self, "sagemaker_session") and self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session + + self._original_deploy = self.pysdk_model.deploy + self.pysdk_model.deploy = self._model_builder_deploy_wrapper + self._original_register = self.pysdk_model.register + self.pysdk_model.register = self._model_builder_register_wrapper + self.model_package = None + return self.pysdk_model + + def _build_for_tensorflow_serving(self): + """Build the model for Tensorflow Serving""" + self._validate_for_tensorflow_serving() + self._save_schema_builder() + + if not 
self.image_uri: + raise ValueError("image_uri is not set for tensorflow serving") + + self.secret_key = prepare_for_tf_serving( + model_path=self.model_path, + shared_libs=self.shared_libs, + dependencies=self.dependencies, + ) + + self._prepare_for_mode() + + return self._create_tensorflow_model() diff --git a/src/sagemaker/serve/builder/tgi_builder.py b/src/sagemaker/serve/builder/tgi_builder.py index 23cc7e2202..032056cfec 100644 --- a/src/sagemaker/serve/builder/tgi_builder.py +++ b/src/sagemaker/serve/builder/tgi_builder.py @@ -25,13 +25,14 @@ LocalModelInvocationException, SkipTuningComboException, ) +from sagemaker.serve.utils.optimize_utils import _is_optimized from sagemaker.serve.utils.tuning import ( _serial_benchmark, _concurrent_benchmark, _more_performant, _pretty_print_results_tgi, ) -from sagemaker.djl_inference.model import _get_model_config_properties_from_hf +from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf from sagemaker.serve.model_server.djl_serving.utils import ( _get_admissible_tensor_parallel_degrees, _get_default_tensor_parallel_degree, @@ -48,13 +49,14 @@ _get_gpu_info_fallback, ) from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure -from sagemaker.serve.utils.predictors import TgiLocalModePredictor +from sagemaker.serve.utils.predictors import TgiLocalModePredictor, InProcessModePredictor from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.telemetry_logger import _capture_telemetry from sagemaker.base_predictor import PredictorBase logger = logging.getLogger(__name__) +LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS] _CODE_FOLDER = "code" _INVALID_SAMPLE_DATA_EX = ( @@ -90,11 +92,11 @@ def __init__(self): self.nb_instance_type = None self.ram_usage_model_load = None self.secret_key = None - self.jumpstart = None self.role_arn = None + self.name = None @abstractmethod - def _prepare_for_mode(self): + def _prepare_for_mode(self, *args, **kwargs): """Placeholder docstring""" @abstractmethod @@ -142,6 +144,7 @@ def _create_tgi_model(self) -> Type[Model]: env=self.env_vars, role=self.role_arn, sagemaker_session=self.sagemaker_session, + name=self.name, ) self._original_deploy = pysdk_model.deploy @@ -174,6 +177,17 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa serializer = self.schema_builder.input_serializer deserializer = self.schema_builder._output_deserializer + + if self.mode == Mode.IN_PROCESS: + predictor = InProcessModePredictor( + self.modes[str(Mode.IN_PROCESS)], serializer, deserializer + ) + + self.modes[str(Mode.IN_PROCESS)].create_server( + predictor, + ) + return predictor + if self.mode == Mode.LOCAL_CONTAINER: timeout = kwargs.get("model_data_download_timeout") @@ -202,18 +216,26 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = self._prepare_for_mode() - self.env_vars.update(env_vars) - self.pysdk_model.env.update(self.env_vars) + if not _is_optimized(self.pysdk_model): + env_vars = {} + if str(Mode.LOCAL_CONTAINER) in self.modes: + # upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT + self.pysdk_model.model_data, env_vars = self._prepare_for_mode( + model_path=self.model_path, should_upload_artifacts=True + ) + else: + _, env_vars = self._prepare_for_mode() + + 
self.env_vars.update(env_vars) + self.pysdk_model.env.update(self.env_vars) # if the weights have been cached via local container mode -> set to offline if str(Mode.LOCAL_CONTAINER) in self.modes: - self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"}) + self.pysdk_model.env.update({"HF_HUB_OFFLINE": "1"}) else: # if has not been built for local container we must use cache # that hosting has write access to. - self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp" + self.pysdk_model.env["HF_HOME"] = "/tmp" self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp" if "endpoint_logging" not in kwargs: @@ -243,7 +265,8 @@ def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa predictor = self._original_deploy(*args, **kwargs) - self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "0"}) + if "HF_HUB_OFFLINE" in self.pysdk_model.env: + self.pysdk_model.env.update({"HF_HUB_OFFLINE": "0"}) predictor.serializer = serializer predictor.deserializer = deserializer @@ -269,7 +292,7 @@ def _build_for_hf_tgi(self): ] = _default_max_new_tokens self.pysdk_model = self._create_tgi_model() - if self.mode == Mode.LOCAL_CONTAINER: + if self.mode in LOCAL_MODES: self._prepare_for_mode() return self.pysdk_model @@ -473,4 +496,8 @@ def _build_for_tgi(self): self.pysdk_model = self._build_for_hf_tgi() self.pysdk_model.tune = self._tune_for_hf_tgi + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py index 3d84e314df..0388a9a05d 100644 --- a/src/sagemaker/serve/builder/transformers_builder.py +++ b/src/sagemaker/serve/builder/transformers_builder.py @@ -13,8 +13,11 @@ """Transformers build logic with model builder""" from __future__ import absolute_import import logging +import os from abc import ABC, abstractmethod from typing import Type +from pathlib import Path +import subprocess from packaging.version import Version from sagemaker.model import Model @@ -22,20 +25,32 @@ from sagemaker.serve.utils.local_hardware import ( _get_nb_instance, ) -from sagemaker.djl_inference.model import _get_model_config_properties_from_hf +from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf from sagemaker.huggingface import HuggingFaceModel from sagemaker.serve.model_server.multi_model_server.prepare import ( _create_dir_structure, + prepare_for_mms, +) +from sagemaker.serve.detector.image_detector import ( + auto_detect_container, +) +from sagemaker.serve.detector.pickler import save_pkl +from sagemaker.serve.utils.optimize_utils import _is_optimized +from sagemaker.serve.utils.predictors import ( + TransformersLocalModePredictor, + InProcessModePredictor, ) -from sagemaker.serve.utils.predictors import TransformersLocalModePredictor from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.telemetry_logger import _capture_telemetry from sagemaker.base_predictor import PredictorBase from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata +from sagemaker.serve.builder.requirements_manager import RequirementsManager + logger = logging.getLogger(__name__) DEFAULT_TIMEOUT = 1800 +LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS] """Retrieves images for different libraries - Pytorch, TensorFlow from HuggingFace hub @@ -72,12 +87,35 @@ def 
__init__(self): self.pytorch_version = None self.instance_type = None self.schema_builder = None + self.inference_spec = None + self.shared_libs = None + self.name = None @abstractmethod - def _prepare_for_mode(self): + def _prepare_for_mode(self, *args, **kwargs): """Abstract method""" def _create_transformers_model(self) -> Type[Model]: + """Initializes HF model with or without image_uri""" + if self.image_uri is None: + pysdk_model = self._get_hf_metadata_create_model() + else: + pysdk_model = HuggingFaceModel( + image_uri=self.image_uri, + vpc_config=self.vpc_config, + env=self.env_vars, + role=self.role_arn, + sagemaker_session=self.sagemaker_session, + name=self.name, + ) + + logger.info("Detected %s. Proceeding with the deployment.", self.image_uri) + + self._original_deploy = pysdk_model.deploy + pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper + return pysdk_model + + def _get_hf_metadata_create_model(self) -> Type[Model]: """Initializes the model after fetching image 1. Get the metadata for deciding framework @@ -90,7 +128,7 @@ def _create_transformers_model(self) -> Type[Model]: """ hf_model_md = get_huggingface_model_metadata( - self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) hf_config = image_uris.config_for_framework("huggingface").get("inference") config = hf_config["versions"] @@ -132,19 +170,21 @@ def _create_transformers_model(self) -> Type[Model]: vpc_config=self.vpc_config, ) - if self.mode == Mode.LOCAL_CONTAINER: + if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER: self.image_uri = pysdk_model.serving_image_uri( self.sagemaker_session.boto_region_name, "local" ) - else: + elif not self.image_uri: self.image_uri = pysdk_model.serving_image_uri( self.sagemaker_session.boto_region_name, self.instance_type ) - logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri) + if pysdk_model is None or self.image_uri is None: + raise ValueError("PySDK model unable to be created, try overriding image_uri") + + if not pysdk_model.image_uri: + pysdk_model.image_uri = self.image_uri - self._original_deploy = pysdk_model.deploy - pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper return pysdk_model @_capture_telemetry("transformers.deploy") @@ -175,8 +215,6 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr else: raise ValueError("Mode %s is not supported!" 
% overwrite_mode) - self._set_instance() - serializer = self.schema_builder.input_serializer deserializer = self.schema_builder._output_deserializer if self.mode == Mode.LOCAL_CONTAINER: @@ -196,16 +234,44 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr ) return predictor + if self.mode == Mode.IN_PROCESS: + timeout = kwargs.get("model_data_download_timeout") + + predictor = InProcessModePredictor( + self.modes[str(Mode.IN_PROCESS)], serializer, deserializer + ) + + self.modes[str(Mode.IN_PROCESS)].create_server( + predictor, + ) + return predictor + + self._set_instance(kwargs) + if "mode" in kwargs: del kwargs["mode"] if "role" in kwargs: self.pysdk_model.role = kwargs.get("role") del kwargs["role"] - # set model_data to uncompressed s3 dict - self.pysdk_model.model_data, env_vars = self._prepare_for_mode() - self.env_vars.update(env_vars) - self.pysdk_model.env.update(self.env_vars) + if not _is_optimized(self.pysdk_model): + env_vars = {} + if str(Mode.LOCAL_CONTAINER) in self.modes: + # upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT + self.pysdk_model.model_data, env_vars = self._prepare_for_mode( + model_path=self.model_path, should_upload_artifacts=True + ) + else: + _, env_vars = self._prepare_for_mode() + + self.env_vars.update(env_vars) + self.pysdk_model.env.update(self.env_vars) + + if ( + "SAGEMAKER_SERVE_SECRET_KEY" in self.pysdk_model.env + and not self.pysdk_model.env["SAGEMAKER_SERVE_SECRET_KEY"] + ): + del self.pysdk_model.env["SAGEMAKER_SERVE_SECRET_KEY"] if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True @@ -225,39 +291,46 @@ def _build_transformers_env(self): _create_dir_structure(self.model_path) if not hasattr(self, "pysdk_model"): - self.env_vars.update({"HF_MODEL_ID": self.model}) + + if self.inference_spec is not None: + self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()}) + else: + self.env_vars.update({"HF_MODEL_ID": self.model}) logger.info(self.env_vars) # TODO: Move to a helper function if hasattr(self.env_vars, "HF_API_TOKEN"): self.hf_model_config = _get_model_config_properties_from_hf( - self.model, self.env_vars.get("HF_API_TOKEN") + self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_API_TOKEN") ) else: self.hf_model_config = _get_model_config_properties_from_hf( - self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") + self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) self.pysdk_model = self._create_transformers_model() - if self.mode == Mode.LOCAL_CONTAINER: + if self.mode in LOCAL_MODES: self._prepare_for_mode() return self.pysdk_model - def _set_instance(self, **kwargs): + def _set_instance(self, kwargs): """Set the instance : Given the detected notebook type or provided instance type""" if self.mode == Mode.SAGEMAKER_ENDPOINT: + if "instance_type" in kwargs: + return if self.nb_instance_type and "instance_type" not in kwargs: kwargs.update({"instance_type": self.nb_instance_type}) + logger.info("Setting instance type to %s", self.nb_instance_type) elif self.instance_type and "instance_type" not in kwargs: kwargs.update({"instance_type": self.instance_type}) + logger.info("Setting instance type to %s", self.instance_type) else: raise ValueError( "Instance type must be provided when deploying to SageMaker Endpoint mode." 
) - logger.info("Setting instance type to %s", self.instance_type) def _get_supported_version(self, hf_config, hugging_face_version, base_fw): """Uses the hugging face json config to pick supported versions""" @@ -271,6 +344,42 @@ def _get_supported_version(self, hf_config, hugging_face_version, base_fw): versions_to_return.append(base_fw_version) return sorted(versions_to_return, reverse=True)[0] + def _auto_detect_container(self): + """Set image_uri by detecting container via model name or inference spec""" + # Auto detect the container image uri + if self.image_uri: + logger.info( + "Skipping auto detection as the image uri is provided %s", + self.image_uri, + ) + return + + if self.model: + logger.info( + "Auto-detecting container URI for the provided model on instance %s", + self.instance_type, + ) + self.image_uri = auto_detect_container( + self.model, self.sagemaker_session.boto_region_name, self.instance_type + ) + + elif self.inference_spec: + # TODO: this won't work for larger image. + # Fail and let the customer include the image uri + logger.warning( + "model_path provided with no image_uri. Attempting to autodetect the image " + "by loading the model using inference_spec.load()..." + ) + self.image_uri = auto_detect_container( + self.inference_spec.load(self.model_path), + self.sagemaker_session.boto_region_name, + self.instance_type, + ) + else: + raise ValueError( + "Cannot detect and set image_uri. Please pass model or inference spec." + ) + def _build_for_transformers(self): """Method that triggers model build @@ -279,6 +388,41 @@ def _build_for_transformers(self): self.secret_key = None self.model_server = ModelServer.MMS + if self.inference_spec: + + os.makedirs(self.model_path, exist_ok=True) + + code_path = Path(self.model_path).joinpath("code") + + save_pkl(code_path, (self.inference_spec, self.schema_builder)) + logger.info("PKL file saved to file: %s", code_path) + + if self.mode == Mode.IN_PROCESS: + self._create_conda_env() + + self._auto_detect_container() + + self.secret_key = prepare_for_mms( + model_path=self.model_path, + shared_libs=self.shared_libs, + dependencies=self.dependencies, + session=self.sagemaker_session, + image_uri=self.image_uri, + inference_spec=self.inference_spec, + ) + self._build_transformers_env() + if self.role_arn: + self.pysdk_model.role = self.role_arn + if self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session return self.pysdk_model + + def _create_conda_env(self): + """Creates a conda environment and installs the captured dependencies""" + + try: + RequirementsManager().capture_and_install_dependencies() + except subprocess.CalledProcessError: + logger.error("Failed to create and activate conda environment.") diff --git a/src/sagemaker/serve/detector/dependency_manager.py b/src/sagemaker/serve/detector/dependency_manager.py index e72a84da30..8ff37c9185 100644 --- a/src/sagemaker/serve/detector/dependency_manager.py +++ b/src/sagemaker/serve/detector/dependency_manager.py @@ -34,22 +34,34 @@ def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool = """Placeholder docstring""" path = work_dir.joinpath("requirements.txt") if "auto" in dependencies and dependencies["auto"]: + import site + + pkl_path = work_dir.joinpath(PKL_FILE_NAME) + dest_path = path + site_packages_dir = site.getsitepackages()[0] + pickle_command_dir = "/sagemaker/serve/detector" + + command = [ sys.executable, - Path(__file__).parent.joinpath("pickle_dependencies.py"), - "--pkl_path", - work_dir.joinpath(PKL_FILE_NAME), - "--dest", - path, + 
"-c", ] if capture_all: - command.append("--capture_all") + command.append( + f"from pickle_dependencies import get_all_requirements;" + f'get_all_requirements("{dest_path}")' + ) + else: + command.append( + f"from pickle_dependencies import get_requirements_for_pkl_file;" + f'get_requirements_for_pkl_file("{pkl_path}", "{dest_path}")' + ) subprocess.run( command, env={"SETUPTOOLS_USE_DISTUTILS": "stdlib"}, check=True, + cwd=site_packages_dir + pickle_command_dir, ) with open(path, "r") as f: diff --git a/src/sagemaker/serve/detector/image_detector.py b/src/sagemaker/serve/detector/image_detector.py index 63831f5950..d8bee9deb8 100644 --- a/src/sagemaker/serve/detector/image_detector.py +++ b/src/sagemaker/serve/detector/image_detector.py @@ -43,7 +43,7 @@ def auto_detect_container(model, region: str, instance_type: str) -> str: casted_versions = _cast_to_compatible_version(fw, fw_version) if fw_version else (None,) dlc = None - for casted_version in casted_versions: + for casted_version in filter(None, casted_versions): try: dlc = image_uris.retrieve( framework=fw, diff --git a/src/sagemaker/serve/detector/pickle_dependencies.py b/src/sagemaker/serve/detector/pickle_dependencies.py index 5a1cd43869..8f9da917fd 100644 --- a/src/sagemaker/serve/detector/pickle_dependencies.py +++ b/src/sagemaker/serve/detector/pickle_dependencies.py @@ -3,7 +3,6 @@ from __future__ import absolute_import from pathlib import Path from typing import List -import argparse import email.parser import email.policy import json @@ -129,32 +128,3 @@ def get_all_requirements(dest: Path): version = package_info.get("version") out.write(f"{name}=={version}\n") - - -def parse_args(): - """Placeholder docstring""" - parser = argparse.ArgumentParser( - prog="pkl_requirements", description="Generates a requirements.txt for a cloudpickle file" - ) - parser.add_argument("--pkl_path", required=True, help="path of the pkl file") - parser.add_argument("--dest", required=True, help="path of the destination requirements.txt") - parser.add_argument( - "--capture_all", - action="store_true", - help="capture all dependencies in current environment", - ) - args = parser.parse_args() - return (Path(args.pkl_path), Path(args.dest), args.capture_all) - - -def main(): - """Placeholder docstring""" - pkl_path, dest, capture_all = parse_args() - if capture_all: - get_all_requirements(dest) - else: - get_requirements_for_pkl_file(pkl_path, dest) - - -if __name__ == "__main__": - main() diff --git a/src/sagemaker/serve/mode/in_process_mode.py b/src/sagemaker/serve/mode/in_process_mode.py new file mode 100644 index 0000000000..0c262da6f3 --- /dev/null +++ b/src/sagemaker/serve/mode/in_process_mode.py @@ -0,0 +1,95 @@ +"""Module that defines the InProcessMode class""" + +from __future__ import absolute_import + +from pathlib import Path +import logging +from typing import Dict, Type, Optional +import time +from datetime import datetime, timedelta + +from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.builder.schema_builder import SchemaBuilder +from sagemaker.serve.utils.exceptions import InProcessDeepPingException +from sagemaker.serve.model_server.in_process_model_server.in_process_server import InProcessServing +from sagemaker.session import Session + +logger = logging.getLogger(__name__) + +_PING_HEALTH_CHECK_FAIL_MSG = "Ping health check did not pass. Please review your inference code." 
+ + class InProcessMode(InProcessServing): + """A class that holds methods to host a model in an in-process environment""" + + def __init__( + self, + model: Optional[str], + inference_spec: Optional[InferenceSpec], + schema_builder: Type[SchemaBuilder], + session: Session, + model_path: str = None, + env_vars: Dict = None, + ): + # pylint: disable=bad-super-call + super().__init__() + + self.model = model + self.inference_spec = inference_spec + self.model_path = model_path + self.env_vars = env_vars + self.session = session + self.schema_builder = schema_builder + self._ping_local_server = None + + def load(self, model_path: str = None): + """Loads model path, checks that path exists""" + path = Path(model_path if model_path else self.model_path) + if not path.exists(): + raise ValueError("model_path does not exist") + if not path.is_dir(): + raise ValueError("model_path is not a valid directory") + + return self.inference_spec.load(str(path)) + + def prepare(self): + """Prepares the server""" + + def create_server( + self, + predictor: PredictorBase, + ): + """Creates the FastAPI server and checks ping health.""" + + logger.info("Waiting for FastAPI server to start up...") + + logger.warning("Note: This is not a standard model server.") + logger.warning("The model is being hosted directly on the FastAPI server.") + + self._ping_local_server = self._deep_ping + self._start_serving() + + # allow some time for server to be ready. + time.sleep(1) + + time_limit = datetime.now() + timedelta(seconds=5) + healthy = True + while True: + final_pull = datetime.now() > time_limit + if final_pull: + break + + healthy, response = self._ping_local_server(predictor) + if healthy: + logger.debug("Ping health check has passed. Returned %s", str(response)) + break + + time.sleep(1) + + if not healthy: + raise InProcessDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG) + + def destroy_server(self): + """Stops the in-process server""" + self._stop_serving() diff --git a/src/sagemaker/serve/mode/local_container_mode.py b/src/sagemaker/serve/mode/local_container_mode.py index 362a3804de..f040c61c1d 100644 --- a/src/sagemaker/serve/mode/local_container_mode.py +++ b/src/sagemaker/serve/mode/local_container_mode.py @@ -11,6 +11,7 @@ import docker from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.model_server.tensorflow_serving.server import LocalTensorflowServing from sagemaker.serve.spec.inference_spec import InferenceSpec from sagemaker.serve.builder.schema_builder import SchemaBuilder from sagemaker.serve.utils.logging_agent import pull_logs @@ -20,6 +21,7 @@ from sagemaker.serve.model_server.djl_serving.server import LocalDJLServing from sagemaker.serve.model_server.triton.server import LocalTritonServer from sagemaker.serve.model_server.tgi.server import LocalTgiServing +from sagemaker.serve.model_server.tei.server import LocalTeiServing from sagemaker.serve.model_server.multi_model_server.server import LocalMultiModelServer from sagemaker.session import Session @@ -34,7 +36,12 @@ class LocalContainerMode( - LocalTorchServe, LocalDJLServing, LocalTritonServer, LocalTgiServing, LocalMultiModelServer + LocalTorchServe, + LocalDJLServing, + LocalTritonServer, + LocalTgiServing, + LocalMultiModelServer, + LocalTensorflowServing, ): """A class that holds methods to deploy model to a container in local environment""" @@ -63,6 +70,7 @@ def __init__( self.container = None self.secret_key = None self._ping_container = None + self._invoke_serving = None def load(self, model_path: str = 
None): """Placeholder docstring""" @@ -141,6 +149,28 @@ def create_server( env_vars=env_vars if env_vars else self.env_vars, ) self._ping_container = self._multi_model_server_deep_ping + elif self.model_server == ModelServer.TENSORFLOW_SERVING: + self._start_tensorflow_serving( + client=self.client, + image=image, + model_path=model_path if model_path else self.model_path, + secret_key=secret_key, + env_vars=env_vars if env_vars else self.env_vars, + ) + self._ping_container = self._tensorflow_serving_deep_ping + elif self.model_server == ModelServer.TEI: + tei_serving = LocalTeiServing() + tei_serving._start_tei_serving( + client=self.client, + image=image, + model_path=model_path if model_path else self.model_path, + secret_key=secret_key, + env_vars=env_vars if env_vars else self.env_vars, + ) + tei_serving.schema_builder = self.schema_builder + self.container = tei_serving.container + self._ping_container = tei_serving._tei_deep_ping + self._invoke_serving = tei_serving._invoke_tei_serving # allow some time for container to be ready time.sleep(10) diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py index 0fdc425b31..2b4473a706 100644 --- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py +++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py @@ -6,6 +6,8 @@ import logging from typing import Type +from sagemaker.serve.model_server.tei.server import SageMakerTeiServing +from sagemaker.serve.model_server.tensorflow_serving.server import SageMakerTensorflowServing from sagemaker.session import Session from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.spec.inference_spec import InferenceSpec @@ -14,16 +16,21 @@ from sagemaker.serve.model_server.djl_serving.server import SageMakerDjlServing from sagemaker.serve.model_server.tgi.server import SageMakerTgiServing from sagemaker.serve.model_server.multi_model_server.server import SageMakerMultiModelServer +from sagemaker.serve.model_server.smd.server import SageMakerSmdServer + logger = logging.getLogger(__name__) +# pylint: disable=R0901 class SageMakerEndpointMode( SageMakerTorchServe, SageMakerTritonServer, SageMakerDjlServing, SageMakerTgiServing, SageMakerMultiModelServer, + SageMakerTensorflowServing, + SageMakerSmdServer, ): """Holds the required method to deploy a model to a SageMaker Endpoint""" @@ -35,6 +42,8 @@ def __init__(self, inference_spec: Type[InferenceSpec], model_server: ModelServe self.inference_spec = inference_spec self.model_server = model_server + self._tei_serving = SageMakerTeiServing() + def load(self, model_path: str): """Placeholder docstring""" path = Path(model_path) @@ -54,6 +63,7 @@ def prepare( sagemaker_session: Session = None, image: str = None, jumpstart: bool = False, + should_upload_artifacts: bool = False, ): """Placeholder docstring""" try: @@ -64,47 +74,91 @@ def prepare( + "session to be created or supply `sagemaker_session` into @serve.invoke." 
) from e + upload_artifacts = None if self.model_server == ModelServer.TORCHSERVE: - return self._upload_torchserve_artifacts( + upload_artifacts = self._upload_torchserve_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, secret_key=secret_key, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=True, ) if self.model_server == ModelServer.TRITON: - return self._upload_triton_artifacts( + upload_artifacts = self._upload_triton_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, secret_key=secret_key, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=True, ) if self.model_server == ModelServer.DJL_SERVING: - return self._upload_djl_artifacts( + upload_artifacts = self._upload_djl_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, image=image, + should_upload_artifacts=True, ) + if self.model_server == ModelServer.TENSORFLOW_SERVING: + upload_artifacts = self._upload_tensorflow_serving_artifacts( + model_path=model_path, + sagemaker_session=sagemaker_session, + secret_key=secret_key, + s3_model_data_url=s3_model_data_url, + image=image, + should_upload_artifacts=True, + ) + + # By default, we do not upload artifacts to S3 for the servers below. + # In the case of optimization, artifacts do need to be uploaded to S3. + # In that case, the `should_upload_artifacts` arg needs to come from + # the caller of prepare. + if self.model_server == ModelServer.TGI: - return self._upload_tgi_artifacts( + upload_artifacts = self._upload_tgi_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, image=image, jumpstart=jumpstart, + should_upload_artifacts=should_upload_artifacts, ) if self.model_server == ModelServer.MMS: - return self._upload_server_artifacts( + upload_artifacts = self._upload_server_artifacts( model_path=model_path, sagemaker_session=sagemaker_session, s3_model_data_url=s3_model_data_url, + secret_key=secret_key, image=image, + should_upload_artifacts=should_upload_artifacts, ) + if self.model_server == ModelServer.TEI: + upload_artifacts = self._tei_serving._upload_tei_artifacts( + model_path=model_path, + sagemaker_session=sagemaker_session, + s3_model_data_url=s3_model_data_url, + image=image, + should_upload_artifacts=should_upload_artifacts, + ) + + if self.model_server == ModelServer.SMD: + upload_artifacts = self._upload_smd_artifacts( + model_path=model_path, + sagemaker_session=sagemaker_session, + secret_key=secret_key, + s3_model_data_url=s3_model_data_url, + image=image, + should_upload_artifacts=True, + ) + + if upload_artifacts is not None: + return upload_artifacts + raise ValueError("%s model server is not supported" % self.model_server) diff --git a/src/sagemaker/serve/model_format/mlflow/constants.py b/src/sagemaker/serve/model_format/mlflow/constants.py index 00ef76170c..ff7553ea5f 100644 --- a/src/sagemaker/serve/model_format/mlflow/constants.py +++ b/src/sagemaker/serve/model_format/mlflow/constants.py @@ -18,7 +18,15 @@ "py38": "1.12.1", "py39": "1.13.1", "py310": "2.2.0", + "py311": "2.3.0", } +MODEL_PACKAGE_ARN_REGEX = ( + r"^arn:aws:sagemaker:[a-z0-9\-]+:[0-9]{12}:model-package\/(.*?)(?:/(\d+))?$" +) +MLFLOW_RUN_ID_REGEX = r"^runs:/[a-zA-Z0-9]+(/[a-zA-Z0-9\-_\.]*)+$" +MLFLOW_REGISTRY_PATH_REGEX = r"^models:/[a-zA-Z0-9\-_\.]+[@/]?[a-zA-Z0-9\-_\.][/a-zA-Z0-9\-_\.]*$" +S3_PATH_REGEX = 
r"^s3:\/\/[a-zA-Z0-9\-_\.]+(?:\/[a-zA-Z0-9\-_\/\.]*)?$" +MLFLOW_TRACKING_ARN = "MLFLOW_TRACKING_ARN" MLFLOW_MODEL_PATH = "MLFLOW_MODEL_PATH" MLFLOW_METADATA_FILE = "MLmodel" MLFLOW_PIP_DEPENDENCY_FILE = "requirements.txt" @@ -34,8 +42,12 @@ "spark": "pyspark", "onnx": "onnxruntime", } -FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = [ # will extend to keras and tf - "sklearn", - "pytorch", - "xgboost", -] +TENSORFLOW_SAVED_MODEL_NAME = "saved_model.pb" +FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT = { + "sklearn": "sklearn", + "pytorch": "pytorch", + "xgboost": "xgboost", + "tensorflow": "tensorflow", + "keras": "tensorflow", +} +FLAVORS_DEFAULT_WITH_TF_SERVING = ["keras", "tensorflow"] diff --git a/src/sagemaker/serve/model_format/mlflow/utils.py b/src/sagemaker/serve/model_format/mlflow/utils.py index c9a8093a79..69082fe575 100644 --- a/src/sagemaker/serve/model_format/mlflow/utils.py +++ b/src/sagemaker/serve/model_format/mlflow/utils.py @@ -13,7 +13,8 @@ """Holds the util functions used for MLflow model format""" from __future__ import absolute_import -from typing import Optional, Dict, Any +from pathlib import Path +from typing import Optional, Dict, Any, Union import yaml import logging import shutil @@ -30,6 +31,8 @@ DEFAULT_PYTORCH_VERSION, MLFLOW_METADATA_FILE, MLFLOW_PIP_DEPENDENCY_FILE, + FLAVORS_DEFAULT_WITH_TF_SERVING, + TENSORFLOW_SAVED_MODEL_NAME, ) logger = logging.getLogger(__name__) @@ -44,7 +47,8 @@ def _get_default_model_server_for_mlflow(deployment_flavor: str) -> ModelServer: Returns: str: The model server chosen for given model flavor. """ - # TODO: implement real logic here based on mlflow flavor + if deployment_flavor in FLAVORS_DEFAULT_WITH_TF_SERVING: + return ModelServer.TENSORFLOW_SERVING return ModelServer.TORCHSERVE @@ -223,28 +227,6 @@ def _get_python_version_from_parsed_mlflow_model_file( raise ValueError(f"{MLFLOW_PYFUNC} cannot be found in MLmodel file.") -def _mlflow_input_is_local_path(model_path: str) -> bool: - """Checks if the given model_path is a local filesystem path. - - Args: - - model_path (str): The model path to check. - - Returns: - - bool: True if model_path is a local path, False otherwise. - """ - if model_path.startswith("s3://"): - return False - - if "/runs/" in model_path or model_path.startswith("runs:"): - return False - - # Check if it's not a local file path - if not os.path.exists(model_path): - return False - - return True - - def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> None: """Downloads all artifacts from a specified S3 path to a local destination path. @@ -274,7 +256,7 @@ def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> Non os.makedirs(local_file_dir, exist_ok=True) # Download the file - print(f"Downloading {key} to {local_file_path}") + logger.info(f"Downloading {key} to {local_file_path}") s3.download_file(s3_bucket, key, local_file_path) @@ -344,24 +326,34 @@ def _select_container_for_mlflow_model( f"specific DLC support. Defaulting to generic image..." 
) return _get_default_image_for_mlflow(python_version, region, instance_type) - framework_version = _get_framework_version_from_requirements( - deployment_flavor, requirement_path - ) + + framework_to_use = FLAVORS_WITH_FRAMEWORK_SPECIFIC_DLC_SUPPORT.get(deployment_flavor) + framework_version = _get_framework_version_from_requirements(framework_to_use, requirement_path) logger.info("Auto-detected deployment flavor is %s", deployment_flavor) + logger.info("Auto-detected framework to use is %s", framework_to_use) logger.info("Auto-detected framework version is %s", framework_version) + if framework_version is None: + raise ValueError( + ( + "Unable to auto detect framework version. Please provide framework %s as part of the " + "requirements.txt file for deployment flavor %s" + ) + % (framework_to_use, deployment_flavor) + ) + casted_versions = ( - _cast_to_compatible_version(deployment_flavor, framework_version) + _cast_to_compatible_version(framework_to_use, framework_version) if framework_version else (None,) ) image_uri = None - for casted_version in casted_versions: + for casted_version in filter(None, casted_versions): try: image_uri = image_uris.retrieve( - framework=deployment_flavor, + framework=framework_to_use, region=region, version=casted_version, image_scope="inference", @@ -392,17 +384,60 @@ def _select_container_for_mlflow_model( ) -def _validate_input_for_mlflow(model_server: ModelServer) -> None: +def _validate_input_for_mlflow(model_server: ModelServer, deployment_flavor: str) -> None: """Validates arguments provided with mlflow models. Args: - model_server (ModelServer): Model server used for orchestrating mlflow model. + - deployment_flavor (str): The flavor mlflow model will be deployed with. Raises: - ValueError: If model server is not torchserve. """ - if model_server != ModelServer.TORCHSERVE: + if model_server != ModelServer.TORCHSERVE and model_server != ModelServer.TENSORFLOW_SERVING: raise ValueError( f"{model_server} is currently not supported for MLflow Model. " f"Please choose another model server." ) + if ( + model_server == ModelServer.TENSORFLOW_SERVING + and deployment_flavor not in FLAVORS_DEFAULT_WITH_TF_SERVING + ): + raise ValueError( + "Tensorflow Serving is currently only supported for the following " + "deployment flavors: {}".format(FLAVORS_DEFAULT_WITH_TF_SERVING) + ) + + +def _get_saved_model_path_for_tensorflow_and_keras_flavor(model_path: str) -> Optional[str]: + """Recursively searches for tensorflow saved model. + + Args: + model_path (str): The root directory to start the search from. + + Returns: + Optional[str]: The absolute path to the directory containing 'saved_model.pb'. + """ + for dirpath, dirnames, filenames in os.walk(model_path): + if TENSORFLOW_SAVED_MODEL_NAME in filenames: + return os.path.abspath(dirpath) + + return None + + +def _move_contents(src_dir: Union[str, Path], dest_dir: Union[str, Path]) -> None: + """Moves all contents of a source directory to a specified destination directory. + + Args: + src_dir (Union[str, Path]): The path to the source directory. + dest_dir (Union[str, Path]): The path to the destination directory. 
+ + """ + _src_dir = Path(os.path.normpath(src_dir)) + _dest_dir = Path(os.path.normpath(dest_dir)) + + _dest_dir.mkdir(parents=True, exist_ok=True) + + for item in _src_dir.iterdir(): + _dest_path = _dest_dir / item.name + shutil.move(str(item), str(_dest_path)) diff --git a/src/sagemaker/serve/model_server/djl_serving/inference.py b/src/sagemaker/serve/model_server/djl_serving/inference.py deleted file mode 100644 index 2dba9eb877..0000000000 --- a/src/sagemaker/serve/model_server/djl_serving/inference.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""DJL Handler Template - -Getting Started DJL Handle provided via ModelBuilder. -Feel free to re-purpose this script for your DJL usecase -and re-deploy via ModelBuilder().deploy(). -""" -from __future__ import absolute_import - -from djl_python.inputs import Input -from djl_python.outputs import Output - - -class HandleTemplate: - """A DJL Handler class template that uses the default DeepSpeed, FasterTransformer, and HuggingFaceAccelerate Handlers - - Reference the default handlers here: - - https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/deepspeed.py - - https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/fastertransformer.py - - https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/huggingface.py - """ - - def __init__(self): - self.initialized = False - self.handle = None - - def initialize(self, inputs: Input): - """Template method to load you model with specified engine.""" - self.initialized = True - - if "DeepSpeed" == inputs.get_property("engine"): - from djl_python.deepspeed import handle - elif "FasterTransformer" == inputs.get_property("engine"): - from djl_python.fastertransformer import handle - else: - from djl_python.huggingface import handle - - self._handle = handle - - def inference(self, inputs: Input): - """Template method used to invoke the model. 
Please implement this if you'd like to construct your own script""" - - -_handle_template = HandleTemplate() - - -def handle(inputs: Input) -> Output: - """Driver function required by djl-serving""" - if not _handle_template.initialized: - _handle_template.initialize(inputs) - - return _handle_template._handle(inputs) diff --git a/src/sagemaker/serve/model_server/djl_serving/prepare.py b/src/sagemaker/serve/model_server/djl_serving/prepare.py index 810acc8aff..40cb04152c 100644 --- a/src/sagemaker/serve/model_server/djl_serving/prepare.py +++ b/src/sagemaker/serve/model_server/djl_serving/prepare.py @@ -13,7 +13,6 @@ """Prepare DjlModel for Deployment""" from __future__ import absolute_import -import shutil import json import tarfile import logging @@ -22,139 +21,51 @@ from sagemaker.utils import _tmpdir, custom_extractall_tarfile from sagemaker.s3 import S3Downloader -from sagemaker.djl_inference import DJLModel -from sagemaker.djl_inference.model import _read_existing_serving_properties from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage -_SERVING_PROPERTIES_FILE = "serving.properties" -_ENTRY_POINT_SCRIPT = "inference.py" _SETTING_PROPERTY_STMT = "Setting property: %s to %s" logger = logging.getLogger(__name__) -def _has_serving_properties_file(code_dir: Path) -> bool: - """Check for existing serving properties in the directory""" - return code_dir.joinpath(_SERVING_PROPERTIES_FILE).is_file() - - -def _move_to_code_dir(js_model_dir: str, code_dir: Path): - """Move DJL Jumpstart resources from model to code_dir""" - js_model_resources = Path(js_model_dir).joinpath("model") - for resource in js_model_resources.glob("*"): - try: - shutil.move(resource, code_dir) - except shutil.Error as e: - if "already exists" in str(e): - continue - - -def _extract_js_resource(js_model_dir: str, js_id: str): +def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str): """Uncompress the jumpstart resource""" tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz") with tarfile.open(str(tmp_sourcedir)) as resources: - custom_extractall_tarfile(resources, js_model_dir) + custom_extractall_tarfile(resources, code_dir) -def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path): +def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> tuple: """Copy the associated JumpStart Resource into the code directory""" logger.info("Downloading JumpStart artifacts from S3...") s3_downloader = S3Downloader() - invalid_model_data_format = False - with _tmpdir(directory=str(code_dir)) as js_model_dir: - if isinstance(model_data, str): - if model_data.endswith(".tar.gz"): - logger.info("Uncompressing JumpStart artifacts for faster loading...") - s3_downloader.download(model_data, js_model_dir) - _extract_js_resource(js_model_dir, js_id) - else: - logger.info("Copying uncompressed JumpStart artifacts...") + if isinstance(model_data, str): + if model_data.endswith(".tar.gz"): + logger.info("Uncompressing JumpStart artifacts for faster loading...") + with _tmpdir(directory=str(code_dir)) as js_model_dir: s3_downloader.download(model_data, js_model_dir) - elif ( - isinstance(model_data, dict) - and model_data.get("S3DataSource") - and model_data.get("S3DataSource").get("S3Uri") - ): - logger.info("Copying uncompressed JumpStart artifacts...") - s3_downloader.download(model_data.get("S3DataSource").get("S3Uri"), js_model_dir) + _extract_js_resource(js_model_dir, code_dir, js_id) else: - invalid_model_data_format = True 
- if not invalid_model_data_format: - _move_to_code_dir(js_model_dir, code_dir) - - if invalid_model_data_format: + logger.info("Copying uncompressed JumpStart artifacts...") + s3_downloader.download(model_data, code_dir) + elif ( + isinstance(model_data, dict) + and model_data.get("S3DataSource") + and model_data.get("S3DataSource").get("S3Uri") + ): + logger.info("Copying uncompressed JumpStart artifacts...") + s3_downloader.download(model_data.get("S3DataSource").get("S3Uri"), code_dir) + else: raise ValueError("JumpStart model data compression format is unsupported: %s", model_data) - existing_properties = _read_existing_serving_properties(code_dir) config_json_file = code_dir.joinpath("config.json") - hf_model_config = None if config_json_file.is_file(): with open(str(config_json_file)) as config_json: hf_model_config = json.load(config_json) - return (existing_properties, hf_model_config, True) - - -def _generate_properties_file( - model: DJLModel, code_dir: Path, overwrite_props_from_file: bool, manual_set_props: dict -): - """Construct serving properties file taking into account of overrides or manual specs""" - if _has_serving_properties_file(code_dir): - existing_properties = _read_existing_serving_properties(code_dir) - else: - existing_properties = {} - - serving_properties_dict = model.generate_serving_properties() - serving_properties_file = code_dir.joinpath(_SERVING_PROPERTIES_FILE) - - with open(serving_properties_file, mode="w+") as file: - covered_keys = set() - - if manual_set_props: - for key, value in manual_set_props.items(): - logger.info(_SETTING_PROPERTY_STMT, key, value.strip()) - covered_keys.add(key) - file.write(f"{key}={value}") - - for key, value in serving_properties_dict.items(): - if not overwrite_props_from_file: - logger.info(_SETTING_PROPERTY_STMT, key, value) - file.write(f"{key}={value}\n") - else: - existing_property = existing_properties.get(key) - covered_keys.add(key) - if not existing_property: - logger.info(_SETTING_PROPERTY_STMT, key, value) - file.write(f"{key}={value}\n") - else: - logger.info(_SETTING_PROPERTY_STMT, key, existing_property.strip()) - file.write(f"{key}={existing_property}") - - if overwrite_props_from_file: - # for addition provided properties - for key, value in existing_properties.items(): - if key not in covered_keys: - logger.info(_SETTING_PROPERTY_STMT, key, value.strip()) - file.write(f"{key}={value}") - - -def _store_share_libs(model_path: Path, shared_libs): - """Placeholder Docstring""" - shared_libs_dir = model_path.joinpath("shared_libs") - shared_libs_dir.mkdir(exist_ok=True) - for shared_lib in shared_libs: - shutil.copy2(Path(shared_lib), shared_libs_dir) - - -def _copy_inference_script(code_dir): - """Placeholder Docstring""" - if code_dir.joinpath("inference.py").is_file(): - return - - inference_file = Path(__file__).parent.joinpath(_ENTRY_POINT_SCRIPT) - shutil.copy2(inference_file, code_dir) + return (hf_model_config, True) def _create_dir_structure(model_path: str) -> tuple: @@ -174,36 +85,6 @@ def _create_dir_structure(model_path: str) -> tuple: return (model_path, code_dir) -def prepare_for_djl_serving( - model_path: str, - model: DJLModel, - shared_libs: List[str] = None, - dependencies: str = None, - overwrite_props_from_file: bool = True, - manual_set_props: dict = None, -): - """Prepare serving when a HF model id is given - - Args:to - model_path (str) : Argument - model (DJLModel) : Argument - shared_libs (List[]) : Argument - dependencies (str) : Argument - - Returns: - ( str ) : - - """ - 
model_path, code_dir = _create_dir_structure(model_path) - - if shared_libs: - _store_share_libs(model_path, shared_libs) - - _copy_inference_script(code_dir) - - _generate_properties_file(model, code_dir, overwrite_props_from_file, manual_set_props) - - def prepare_djl_js_resources( model_path: str, js_id: str, diff --git a/src/sagemaker/serve/model_server/djl_serving/server.py b/src/sagemaker/serve/model_server/djl_serving/server.py index 8b152e5b81..4ba7dd227d 100644 --- a/src/sagemaker/serve/model_server/djl_serving/server.py +++ b/src/sagemaker/serve/model_server/djl_serving/server.py @@ -12,6 +12,7 @@ from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host +from sagemaker.serve.utils.optimize_utils import _is_s3_uri logger = logging.getLogger(__name__) MODE_DIR_BINDING = "/opt/ml/model/" @@ -19,6 +20,7 @@ _DEFAULT_ENV_VARS = { "SERVING_OPTS": "-Dai.djl.logging.level=debug", "TRANSFORMERS_CACHE": "/opt/ml/model/", + "HF_HOME": "/opt/ml/model/", "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", } @@ -90,39 +92,48 @@ def _upload_djl_artifacts( s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): """Placeholder docstring""" - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - code_dir = Path(model_path).joinpath("code") + code_dir = Path(model_path).joinpath("code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - logger.debug("Uploading DJL Model Resources uncompressed to: %s", s3_location) + logger.debug("Uploading DJL Model Resources uncompressed to: %s", s3_location) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, - ) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } } - } + if model_data_url + else None + ) return (model_data, _update_env_vars(env_vars)) diff --git a/src/sagemaker/serve/model_server/djl_serving/utils.py b/src/sagemaker/serve/model_server/djl_serving/utils.py index 03719542d2..93d16001df 100644 --- a/src/sagemaker/serve/model_server/djl_serving/utils.py +++ b/src/sagemaker/serve/model_server/djl_serving/utils.py @@ -1,12 +1,8 @@ """DJL ModelBuilder Utils""" from __future__ import absolute_import -from urllib.error import HTTPError import math import logging -from sagemaker.serve.utils.types import 
_DjlEngine -from sagemaker.djl_inference import defaults -from sagemaker.djl_inference.model import _get_model_config_properties_from_hf from sagemaker.serve.utils.local_hardware import _get_available_gpus from sagemaker.serve.builder.schema_builder import SchemaBuilder @@ -17,50 +13,6 @@ TOKENS_PER_WORD = 0.75 -def _auto_detect_engine(model_id: str, hf_hub_token: str) -> tuple: - """Placeholder docstring""" - try: - hf_model_config = _get_model_config_properties_from_hf(model_id, hf_hub_token) - model_type = hf_model_config.get("model_type") - - if len(model_type) < 1: - logger.warning( - "Unable to detect the model architecture from provided model_id %s.\ - Defaulting to HuggingFaceAccelerate." - % model_id - ) - engine = _DjlEngine.HUGGINGFACE_ACCELERATE - elif model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES: - logger.info("Model architecture %s is recommended to be run on DeepSpeed." % model_type) - engine = _DjlEngine.DEEPSPEED - elif model_type in defaults.FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES: - logger.info( - "Model architecture %s is recommended to be run on FasterTransformer." % model_type - ) - engine = _DjlEngine.FASTER_TRANSFORMER - else: - logger.info( - "Model architecture %s does not have a recommended engine. Defaulting to HuggingFaceAccelerate." - % model_type - ) - engine = _DjlEngine.HUGGINGFACE_ACCELERATE - except HTTPError as e: - raise ValueError( - "The provided HuggingFace Model ID could not be accessed from HuggingFace Hub. %s", - str(e), - ) - except ValueError as e: - raise e - except Exception as e: - logger.warning( - "Unable to detect the model's architecture: %s. Defaulting to HuggingFaceAccelerate." - % str(e) - ) - engine = _DjlEngine.HUGGINGFACE_ACCELERATE - - return (engine, hf_model_config) - - def _get_default_tensor_parallel_degree(hf_model_config: dict, gpu_count: int = None) -> int: """Placeholder docstring""" available_gpus = _get_available_gpus() @@ -89,7 +41,7 @@ def _get_default_tensor_parallel_degree(hf_model_config: dict, gpu_count: int = def _get_default_data_type() -> tuple: """Placeholder docstring""" - return "fp16" + return "bf16" def _get_default_batch_size() -> int: @@ -144,22 +96,23 @@ def _get_default_max_tokens(sample_input, sample_output) -> tuple: return (max_total_tokens, max_new_tokens) -def _set_serve_properties(hf_model_config: dict, schema_builder: SchemaBuilder) -> tuple: +def _get_default_djl_configurations( + model_id: str, hf_model_config: dict, schema_builder: SchemaBuilder +) -> tuple: """Placeholder docstring""" default_tensor_parallel_degree = _get_default_tensor_parallel_degree(hf_model_config) + if default_tensor_parallel_degree is None: + default_tensor_parallel_degree = "max" default_data_type = _get_default_data_type() - default_batch_size = _get_default_batch_size() default_max_tokens, default_max_new_tokens = _get_default_max_tokens( schema_builder.sample_input, schema_builder.sample_output ) - return ( - default_tensor_parallel_degree, - default_data_type, - default_batch_size, - default_max_tokens, - default_max_new_tokens, - ) + env = { + "TENSOR_PARALLEL_DEGREE": str(default_tensor_parallel_degree), + "OPTION_DTYPE": default_data_type, + } + return (env, default_max_new_tokens) def _get_admissible_tensor_parallel_degrees(hf_model_config: dict) -> int: diff --git a/src/sagemaker/serve/model_server/in_process_model_server/app.py b/src/sagemaker/serve/model_server/in_process_model_server/app.py new file mode 100644 index 0000000000..18fe63a5fc --- /dev/null +++ 
b/src/sagemaker/serve/model_server/in_process_model_server/app.py @@ -0,0 +1,150 @@ +"""FastAPI requests""" + +from __future__ import absolute_import + +import asyncio +import io +import logging +import threading +import torch +from typing import Optional, Type + +from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.builder.schema_builder import SchemaBuilder + +logger = logging.getLogger(__name__) + + +try: + import uvicorn +except ImportError: + logger.error("Unable to import uvicorn, check if uvicorn is installed.") + + +try: + from fastapi import FastAPI, Request, APIRouter +except ImportError: + logger.error("Unable to import fastapi, check if fastapi is installed.") + + +class InProcessServer: + """Generic In-Process Server for Serving Models using InferenceSpec""" + + def __init__( + self, + model: Optional[str] = None, + inference_spec: Optional[InferenceSpec] = None, + schema_builder: Type[SchemaBuilder] = None, + task: Optional[str] = None, + ): + self._thread = None + self._loop = None + self._stop_event = asyncio.Event() + self._shutdown_event = threading.Event() + self._router = APIRouter() + self._task = task + self.server = None + self.port = None + self.host = None + self.model = model + self.inference_spec = inference_spec + self.schema_builder = schema_builder + + if self.inference_spec: + # Use inference_spec to load the model + self._load_model = self.inference_spec.load(model_dir=None) + elif isinstance(self.model, str): + try: + # Use transformers pipeline to load the model + try: + from transformers import pipeline, Pipeline + except ImportError: + logger.error( + "Unable to import transformers, check if transformers is installed." + ) + + device = 0 if torch.cuda.is_available() else -1 + + self._load_model = pipeline(task, model=self.model, device=device) + except Exception: + logger.info("Falling back to SentenceTransformer for model loading.") + try: + from sentence_transformers import SentenceTransformer + except ImportError: + logger.error( + "Unable to import sentence-transformers, check if sentence-transformers is installed." 
+ ) + + self._load_model = SentenceTransformer(self.model) + else: + raise ValueError("Either inference_spec or model must be provided.") + + @self._router.post("/invoke") + async def invoke(request: Request): + """Generate text based on the provided prompt""" + + request_header = request.headers + request_body = await request.body() + content_type = request_header.get("Content-Type", None) + input_data = schema_builder.input_deserializer.deserialize( + io.BytesIO(request_body), content_type[0] + ) + logger.debug(f"Received request: {input_data}") + if self.inference_spec: + response = self.inference_spec.invoke(input_data, self._load_model) + else: + input_data = input_data["inputs"] if "inputs" in input_data else input_data + if isinstance(self._load_model, Pipeline): + response = self._load_model(input_data, max_length=30, num_return_sequences=1) + else: + embeddings = self._load_model.encode(input_data, normalize_embeddings=True) + response = {"embeddings": embeddings.tolist()} + return response + + self._create_server() + + def _create_server(self): + """Placeholder docstring""" + app = FastAPI() + app.include_router(self._router) + + config = uvicorn.Config( + app, + host="127.0.0.1", + port=9007, + log_level="info", + loop="asyncio", + reload=True, + use_colors=True, + ) + + self.server = uvicorn.Server(config) + self.host = config.host + self.port = config.port + + def start_server(self): + """Starts the uvicorn server.""" + if not (self._thread and self._thread.is_alive()): + logger.info("Waiting for a connection...") + self._thread = threading.Thread(target=self._start_run_async_in_thread, daemon=True) + self._thread.start() + + def stop_server(self): + """Stops the Uvicorn server by setting the shutdown event.""" + if self._thread and self._thread.is_alive(): + logger.info("Shutting down the server...") + self._shutdown_event.set() + self.server.handle_exit(sig=0, frame=None) + self._thread.join() + + logger.info("Server shutdown complete.") + + def _start_run_async_in_thread(self): + """Placeholder docstring""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._serve()) + + async def _serve(self): + """Placeholder docstring""" + await self.server.serve() diff --git a/src/sagemaker/serve/model_server/in_process_model_server/in_process_server.py b/src/sagemaker/serve/model_server/in_process_model_server/in_process_server.py new file mode 100644 index 0000000000..d391fe50a0 --- /dev/null +++ b/src/sagemaker/serve/model_server/in_process_model_server/in_process_server.py @@ -0,0 +1,60 @@ +"""Module for In_process Serving""" + +from __future__ import absolute_import + +import requests +import logging +from sagemaker.serve.utils.exceptions import LocalModelInvocationException +from sagemaker.base_predictor import PredictorBase + +logger = logging.getLogger(__name__) + + +class InProcessServing: + """In Process Mode server instance""" + + def _start_serving(self): + """Initializes the start of the server""" + from sagemaker.serve.model_server.in_process_model_server.app import InProcessServer + + self.server = InProcessServer( + inference_spec=self.inference_spec, model=self.model, schema_builder=self.schema_builder + ) + self.server.start_server() + + def _stop_serving(self): + """Stops the server""" + self.server.stop_server() + + def _invoke_serving(self, request: object, content_type: str, accept: str): + """Placeholder docstring""" + try: + response = requests.post( + f"http://{self.server.host}:{self.server.port}/invoke", + 
data=request, + headers={"Content-Type": content_type, "Accept": accept}, + timeout=600, + ) + response.raise_for_status() + + return response.content + except Exception as e: + if "Connection refused" in str(e): + raise Exception( + "Unable to send request to the local server: Connection refused." + ) from e + raise Exception("Unable to send request to the local container server %s", str(e)) + + def _deep_ping(self, predictor: PredictorBase): + """Sends a deep ping to ensure prediction""" + healthy = False + response = None + try: + response = predictor.predict(self.schema_builder.sample_input) + healthy = response is not None + # pylint: disable=broad-except + except Exception as e: + if "422 Client Error: Unprocessable Entity for url" in str(e): + raise LocalModelInvocationException(str(e)) + + return healthy, response diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py new file mode 100644 index 0000000000..9361765da0 --- /dev/null +++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py @@ -0,0 +1,128 @@ +"""This module is for SageMaker inference.py.""" + +from __future__ import absolute_import +import os +import io +import cloudpickle +import shutil +import platform +from pathlib import Path +from functools import partial +from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.validations.check_integrity import perform_integrity_check +import logging + +logger = logging.getLogger(__name__) + +inference_spec = None +schema_builder = None +SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs") +SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl") +METADATA_PATH = Path(__file__).parent.joinpath("metadata.json") + + +def model_fn(model_dir, context=None): + """Overrides default method for loading a model""" + shared_libs_path = Path(model_dir + "/shared_libs") + + if shared_libs_path.exists(): + # before importing, place dynamic linked libraries in shared lib path + shutil.copytree(shared_libs_path, "/lib", dirs_exist_ok=True) + + serve_path = Path(__file__).parent.joinpath("serve.pkl") + with open(str(serve_path), mode="rb") as file: + global inference_spec, schema_builder + obj = cloudpickle.load(file) + if isinstance(obj[0], InferenceSpec): + inference_spec, schema_builder = obj + + if inference_spec: + return partial(inference_spec.invoke, model=inference_spec.load(model_dir)) + + +def input_fn(input_data, content_type, context=None): + """Deserializes the bytes that were received from the model server""" + try: + if hasattr(schema_builder, "custom_input_translator"): + deserialized_data = schema_builder.custom_input_translator.deserialize( + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type, + ) + else: + deserialized_data = schema_builder.input_deserializer.deserialize( + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type[0], + ) + + # Check if preprocess method is defined and call it + if hasattr(inference_spec, "preprocess"): + return inference_spec.preprocess(deserialized_data) + + return deserialized_data + except Exception as e: + logger.error("Encountered error: %s in deserialize_response." 
% e) + raise Exception("Encountered error in deserialize_request.") from e + + +def predict_fn(input_data, predict_callable, context=None): + """Invokes the model that is taken in by model server""" + return predict_callable(input_data) + + +def output_fn(predictions, accept_type, context=None): + """Prediction is serialized to bytes and sent back to the customer""" + try: + if hasattr(inference_spec, "postprocess"): + predictions = inference_spec.postprocess(predictions) + if hasattr(schema_builder, "custom_output_translator"): + return schema_builder.custom_output_translator.serialize(predictions, accept_type) + else: + return schema_builder.output_serializer.serialize(predictions) + except Exception as e: + logger.error("Encountered error: %s in serialize_response." % e) + raise Exception("Encountered error in serialize_response.") from e + + +def _run_preflight_diagnostics(): + _py_vs_parity_check() + _pickle_file_integrity_check() + + +def _py_vs_parity_check(): + container_py_vs = platform.python_version() + local_py_vs = os.getenv("LOCAL_PYTHON") + + if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]: + logger.warning( + f"The local python version {local_py_vs} differs from the python version " + f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior" + ) + + +def _pickle_file_integrity_check(): + with open(SERVE_PATH, "rb") as f: + buffer = f.read() + + perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH) + + +# on import, execute +_run_preflight_diagnostics() diff --git a/src/sagemaker/serve/model_server/multi_model_server/prepare.py b/src/sagemaker/serve/model_server/multi_model_server/prepare.py index 7a16cc0a43..e3abc70dd6 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/prepare.py +++ b/src/sagemaker/serve/model_server/multi_model_server/prepare.py @@ -14,10 +14,23 @@ from __future__ import absolute_import import logging -from pathlib import Path +from sagemaker.serve.model_server.tgi.prepare import _copy_jumpstart_artifacts from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage +from pathlib import Path +import shutil +from typing import List + +from sagemaker.session import Session +from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.detector.dependency_manager import capture_dependencies +from sagemaker.serve.validations.check_integrity import ( + generate_secret_key, + compute_hash, +) +from sagemaker.remote_function.core.serialization import _MetaData + logger = logging.getLogger(__name__) @@ -36,3 +49,82 @@ def _create_dir_structure(model_path: str) -> tuple: _check_docker_disk_usage() return model_path, code_dir + + +def prepare_mms_js_resources( + model_path: str, + js_id: str, + shared_libs: List[str] = None, + dependencies: str = None, + model_data: str = None, +) -> tuple: + """Prepare serving when a JumpStart model id is given + + Args: + model_path (str) : Argument + js_id (str): Argument + shared_libs (List[]) : Argument + dependencies (str) : Argument + model_data (str) : Argument + + Returns: + ( str ) : + + """ + model_path, code_dir = _create_dir_structure(model_path) + + return _copy_jumpstart_artifacts(model_data, js_id, code_dir) + + +def prepare_for_mms( + model_path: str, + shared_libs: List[str], + dependencies: dict, + session: Session, + image_uri: str, + inference_spec: InferenceSpec = None, +) -> str: + """Prepares for InferenceSpec using model_path, writes inference.py, \ + and captures 
dependencies to generate secret_key. + + Args: + model_path (str) : Argument + shared_libs (List[]) : Argument + dependencies (dict) : Argument + session (Session) : Argument + inference_spec (InferenceSpec, optional) : Argument + (default is None) + Returns: + ( str ) : secret_key + """ + model_path = Path(model_path) + if not model_path.exists(): + model_path.mkdir() + elif not model_path.is_dir(): + raise Exception("model_path is not a valid directory") + + if inference_spec: + inference_spec.prepare(str(model_path)) + + code_dir = model_path.joinpath("code") + code_dir.mkdir(exist_ok=True) + + shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir) + + logger.info("Finished writing inference.py to code directory") + + shared_libs_dir = model_path.joinpath("shared_libs") + shared_libs_dir.mkdir(exist_ok=True) + for shared_lib in shared_libs: + shutil.copy2(Path(shared_lib), shared_libs_dir) + + capture_dependencies(dependencies=dependencies, work_dir=code_dir) + + secret_key = generate_secret_key() + with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: + buffer = f.read() + hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: + metadata.write(_MetaData(hash_value).to_json()) + + return secret_key diff --git a/src/sagemaker/serve/model_server/multi_model_server/server.py b/src/sagemaker/serve/model_server/multi_model_server/server.py index b78e01f5c3..2fab727c05 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/server.py +++ b/src/sagemaker/serve/model_server/multi_model_server/server.py @@ -4,13 +4,16 @@ import requests import logging +import platform from pathlib import Path + from sagemaker import Session, fw_utils from sagemaker.serve.utils.exceptions import LocalModelInvocationException from sagemaker.base_predictor import PredictorBase from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host +from sagemaker.serve.utils.optimize_utils import _is_s3_uri MODE_DIR_BINDING = "/opt/ml/model/" _DEFAULT_ENV_VARS = {} @@ -29,7 +32,18 @@ def _start_serving( self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict, ): - """Placeholder docstring""" + """Starts the MMS serving container""" + env = { + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_SERVE_SECRET_KEY": secret_key, + "LOCAL_PYTHON": platform.python_version(), + } + if env_vars: + env_vars.update(env) + else: + env_vars = env + self.container = client.containers.run( image, "serve", @@ -42,11 +56,11 @@ def _start_serving( "mode": "rw", }, }, - environment=_update_env_vars(env_vars), + environment=env_vars, ) def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str): - """Placeholder docstring""" + """Invokes MMS server by hitting the docker host""" try: response = requests.post( f"http://{get_docker_host()}:8080/invocations", @@ -60,7 +74,7 @@ def _invoke_multi_model_server_serving(self, request: object, content_type: str, raise Exception("Unable to send request to the local container server") from e def _multi_model_server_deep_ping(self, predictor: PredictorBase): - """Placeholder docstring""" + """Deep ping in order to ensure prediction""" response = None try: response = predictor.predict(self.schema_builder.sample_input) @@ -80,42 +94,63 @@ class SageMakerMultiModelServer: def _upload_server_artifacts( self, model_path: str, + 
secret_key: str, sagemaker_session: Session, s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + code_dir = Path(model_path).joinpath("code") - code_dir = Path(model_path).joinpath("code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + logger.debug("Uploading Multi Model Server Resources uncompressed to: %s", s3_location) - logger.debug("Uploading Multi Model Server Resources uncompressed to: %s", s3_location) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } + } + if model_data_url + else None ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", + if secret_key: + env_vars = { + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_SERVE_SECRET_KEY": secret_key, + "SAGEMAKER_REGION": sagemaker_session.boto_region_name, + "SAGEMAKER_CONTAINER_LOG_LEVEL": "10", + "LOCAL_PYTHON": platform.python_version(), } - } + return model_data, _update_env_vars(env_vars) diff --git a/src/sagemaker/serve/model_server/smd/custom_execution_inference.py b/src/sagemaker/serve/model_server/smd/custom_execution_inference.py new file mode 100644 index 0000000000..f53677fc69 --- /dev/null +++ b/src/sagemaker/serve/model_server/smd/custom_execution_inference.py @@ -0,0 +1,72 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
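For orientation, the secret key generated at prepare time and injected at serve time through SAGEMAKER_SERVE_SECRET_KEY exists to pin the pickled serve.pkl bytes to the hash recorded in metadata.json. A minimal sketch of that pairing, assuming an HMAC-SHA256 construction (the real helpers are generate_secret_key, compute_hash and perform_integrity_check, whose exact algorithm may differ):

import hashlib
import hmac
import secrets

def sketch_generate_secret_key() -> str:
    # Hypothetical stand-in for generate_secret_key()
    return secrets.token_hex(32)

def sketch_compute_hash(buffer: bytes, secret_key: str) -> str:
    # Hypothetical stand-in for compute_hash(); keys the digest on the secret
    return hmac.new(secret_key.encode("utf-8"), buffer, hashlib.sha256).hexdigest()

# Prepare time writes the digest of serve.pkl into metadata.json; container
# startup repeats the computation and compares the two values.
payload = b"pickled-inference-spec-bytes"  # stands in for serve.pkl contents
key = sketch_generate_secret_key()
assert sketch_compute_hash(payload, key) == sketch_compute_hash(payload, key)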
+"""This module is for SageMaker inference.py.""" + +from __future__ import absolute_import +import asyncio +import os +import platform +import cloudpickle +import logging +from pathlib import Path +from sagemaker.serve.validations.check_integrity import perform_integrity_check + +logger = LOGGER = logging.getLogger("sagemaker") + + +def initialize_custom_orchestrator(): + """Initializes the custom orchestrator.""" + code_dir = os.getenv("SAGEMAKER_INFERENCE_CODE_DIRECTORY", None) + serve_path = Path(code_dir).joinpath("serve.pkl") + with open(str(serve_path), mode="rb") as pkl_file: + return cloudpickle.load(pkl_file) + + +def _run_preflight_diagnostics(): + _py_vs_parity_check() + _pickle_file_integrity_check() + + +def _py_vs_parity_check(): + container_py_vs = platform.python_version() + local_py_vs = os.getenv("LOCAL_PYTHON") + + if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]: + logger.warning( + f"The local python version {local_py_vs} differs from the python version " + f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior" + ) + + +def _pickle_file_integrity_check(): + with open("/opt/ml/model/code/serve.pkl", "rb") as f: + buffer = f.read() + + metadata_path = Path("/opt/ml/model/code/metadata.json") + perform_integrity_check(buffer=buffer, metadata_path=metadata_path) + + +_run_preflight_diagnostics() +custom_orchestrator, _ = initialize_custom_orchestrator() + + +async def handler(request): + """Custom service entry point function. + + :param request: raw input from request + :return: outputs to be send back to client + """ + if asyncio.iscoroutinefunction(custom_orchestrator.handle): + return await custom_orchestrator.handle(request.body) + else: + return custom_orchestrator.handle(request.body) diff --git a/src/sagemaker/serve/model_server/smd/prepare.py b/src/sagemaker/serve/model_server/smd/prepare.py new file mode 100644 index 0000000000..6461e4023f --- /dev/null +++ b/src/sagemaker/serve/model_server/smd/prepare.py @@ -0,0 +1,74 @@ +"""Summary of MyModule. + +Extended discussion of my module. +""" + +from __future__ import absolute_import +import os +from pathlib import Path +import shutil +from typing import List + +from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker.serve.detector.dependency_manager import capture_dependencies +from sagemaker.serve.validations.check_integrity import ( + generate_secret_key, + compute_hash, +) +from sagemaker.remote_function.core.serialization import _MetaData +from sagemaker.serve.spec.inference_base import CustomOrchestrator, AsyncCustomOrchestrator + + +def prepare_for_smd( + model_path: str, + shared_libs: List[str], + dependencies: dict, + inference_spec: InferenceSpec = None, +) -> str: + """Prepares artifacts for SageMaker model deployment. 
+ + Args: + model_path (str) : Path to the model directory. + shared_libs (List[str]) : List of shared library paths. + dependencies (dict) : Dictionary of dependencies to capture. + inference_spec (InferenceSpec, optional) : Inference spec to prepare + (default is None). + + Returns: + (str) : The generated secret key. + + """ + model_path = Path(model_path) + if not model_path.exists(): + model_path.mkdir() + elif not model_path.is_dir(): + raise Exception("model_dir is not a valid directory") + + if inference_spec and isinstance(inference_spec, InferenceSpec): + inference_spec.prepare(str(model_path)) + + code_dir = model_path.joinpath("code") + code_dir.mkdir(exist_ok=True) + + if inference_spec and isinstance(inference_spec, (CustomOrchestrator, AsyncCustomOrchestrator)): + shutil.copy2(Path(__file__).parent.joinpath("custom_execution_inference.py"), code_dir) + os.rename( + str(code_dir.joinpath("custom_execution_inference.py")), + str(code_dir.joinpath("inference.py")), + ) + + shared_libs_dir = model_path.joinpath("shared_libs") + shared_libs_dir.mkdir(exist_ok=True) + for shared_lib in shared_libs: + shutil.copy2(Path(shared_lib), shared_libs_dir) + + capture_dependencies(dependencies=dependencies, work_dir=code_dir) + + secret_key = generate_secret_key() + with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: + buffer = f.read() + hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: + metadata.write(_MetaData(hash_value).to_json()) + + return secret_key diff --git a/src/sagemaker/serve/model_server/smd/server.py b/src/sagemaker/serve/model_server/smd/server.py new file mode 100644 index 0000000000..c700c39727 --- /dev/null +++ b/src/sagemaker/serve/model_server/smd/server.py @@ -0,0 +1,59 @@ +"""Module for SMD Server""" + +from __future__ import absolute_import + +import logging +import platform +from sagemaker.serve.utils.optimize_utils import _is_s3_uri +from sagemaker.session import Session +from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url +from sagemaker import fw_utils +from sagemaker.serve.utils.uploader import upload + +logger = logging.getLogger(__name__) + + +class SageMakerSmdServer: + """SageMakerSmdServer class""" + + def _upload_smd_artifacts( + self, + model_path: str, + sagemaker_session: Session, + secret_key: str, + s3_model_data_url: str = None, + image: str = None, + should_upload_artifacts: bool = False, + ): + """Tars the model artifact, uploads it to the S3 bucket, then prepares the environment variables""" + s3_upload_path = None + if _is_s3_uri(model_path): + s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) + + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) + + env_vars = { + "SAGEMAKER_INFERENCE_CODE_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_INFERENCE_CODE": "inference.handler", + "SAGEMAKER_REGION": sagemaker_session.boto_region_name, + "SAGEMAKER_SERVE_SECRET_KEY": secret_key, + "LOCAL_PYTHON": platform.python_version(), + } + return s3_upload_path, env_vars diff --git 
a/src/sagemaker/serve/model_server/tei/__init__.py b/src/sagemaker/serve/model_server/tei/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/serve/model_server/tei/server.py b/src/sagemaker/serve/model_server/tei/server.py new file mode 100644 index 0000000000..94265e224f --- /dev/null +++ b/src/sagemaker/serve/model_server/tei/server.py @@ -0,0 +1,171 @@ +"""Module for Local TEI Serving""" + +from __future__ import absolute_import + +import requests +import logging +from pathlib import Path +from docker.types import DeviceRequest +from sagemaker import Session, fw_utils +from sagemaker.serve.utils.exceptions import LocalModelInvocationException +from sagemaker.base_predictor import PredictorBase +from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join +from sagemaker.s3 import S3Uploader +from sagemaker.local.utils import get_docker_host +from sagemaker.serve.utils.optimize_utils import _is_s3_uri + +MODE_DIR_BINDING = "/opt/ml/model/" +_SHM_SIZE = "2G" +_DEFAULT_ENV_VARS = { + "HF_HOME": "/opt/ml/model/", + "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", +} + +logger = logging.getLogger(__name__) + + +class LocalTeiServing: + """LocalTeiServing class""" + + def _start_tei_serving( + self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict + ): + """Starts a local tei serving container. + + Args: + client: Docker client + image: Image to use + model_path: Path to the model + secret_key: Secret key to use for authentication + env_vars: Environment variables to set + """ + if env_vars and secret_key: + env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key + + self.container = client.containers.run( + image, + shm_size=_SHM_SIZE, + device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])], + network_mode="host", + detach=True, + auto_remove=True, + volumes={ + Path(model_path).joinpath("code"): { + "bind": MODE_DIR_BINDING, + "mode": "rw", + }, + }, + environment=_update_env_vars(env_vars), + ) + + def _invoke_tei_serving(self, request: object, content_type: str, accept: str): + """Invokes a local tei serving container. + + Args: + request: Request to send + content_type: Content type to use + accept: Accept to use + """ + try: + response = requests.post( + f"http://{get_docker_host()}:8080/invocations", + data=request, + headers={"Content-Type": content_type, "Accept": accept}, + timeout=600, + ) + response.raise_for_status() + return response.content + except Exception as e: + raise Exception("Unable to send request to the local container server") from e + + def _tei_deep_ping(self, predictor: PredictorBase): + """Checks if the local tei serving container is up and running. + + If the container is not up and running, it will raise an exception. + """ + response = None + try: + response = predictor.predict(self.schema_builder.sample_input) + return (True, response) + # pylint: disable=broad-except + except Exception as e: + if "422 Client Error: Unprocessable Entity for url" in str(e): + raise LocalModelInvocationException(str(e)) + return (False, response) + + return (True, response) + + +class SageMakerTeiServing: + """SageMakerTeiServing class""" + + def _upload_tei_artifacts( + self, + model_path: str, + sagemaker_session: Session, + s3_model_data_url: str = None, + image: str = None, + env_vars: dict = None, + should_upload_artifacts: bool = False, + ): + """Uploads the model artifacts to S3. 
+ + Args: + model_path: Path to the model + sagemaker_session: SageMaker session + s3_model_data_url: S3 model data URL + image: Image to use + env_vars: Environment variables to set + should_upload_artifacts: Whether to upload artifacts + """ + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) + + code_dir = Path(model_path).joinpath("code") + + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + + logger.debug("Uploading TEI Model Resources uncompressed to: %s", s3_location) + + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + sagemaker_session, + ) + + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } + } + if model_data_url + else None + ) + + return (model_data, _update_env_vars(env_vars)) + + +def _update_env_vars(env_vars: dict) -> dict: + """Merges the default TEI environment variables with user-provided ones""" + updated_env_vars = {} + updated_env_vars.update(_DEFAULT_ENV_VARS) + if env_vars: + updated_env_vars.update(env_vars) + return updated_env_vars diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/__init__.py b/src/sagemaker/serve/model_server/tensorflow_serving/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/inference.py b/src/sagemaker/serve/model_server/tensorflow_serving/inference.py new file mode 100644 index 0000000000..928278e3c6 --- /dev/null +++ b/src/sagemaker/serve/model_server/tensorflow_serving/inference.py @@ -0,0 +1,147 @@ +"""This module is for SageMaker inference.py.""" + +from __future__ import absolute_import +import os +import io +import json +import cloudpickle +import shutil +import platform +from pathlib import Path +from sagemaker.serve.validations.check_integrity import perform_integrity_check +import logging + +logger = logging.getLogger(__name__) + +schema_builder = None +SHARED_LIBS_DIR = Path(__file__).parent.parent.joinpath("shared_libs") +SERVE_PATH = Path(__file__).parent.joinpath("serve.pkl") +METADATA_PATH = Path(__file__).parent.joinpath("metadata.json") + + +def input_handler(data, context):
 + """Pre-process request input before it is sent to TensorFlow Serving REST API + + Args: + data (obj): the request data, in format of dict or string + context (Context): an object containing request and configuration details + Returns: + (dict): a JSON-serializable dict that contains request body and headers + """ + read_data = data.read() + deserialized_data = None + try: + if hasattr(schema_builder, "custom_input_translator"): + deserialized_data = schema_builder.custom_input_translator.deserialize( + io.BytesIO(read_data), context.request_content_type + ) + else: + deserialized_data = schema_builder.input_deserializer.deserialize( + io.BytesIO(read_data), context.request_content_type + ) + except Exception as e: + logger.error("Encountered error: %s in deserialize_request." 
% e) + raise Exception("Encountered error in deserialize_request.") from e + + try: + return json.dumps({"instances": _convert_for_serialization(deserialized_data)}) + except Exception as e: + logger.error( + "Encountered error: %s in deserialize_request. " + "Deserialized data is not json serializable. " % e + ) + raise Exception("Encountered error in deserialize_request.") from e + + +def output_handler(data, context): + """Post-process TensorFlow Serving output before it is returned to the client. + + Args: + data (obj): the TensorFlow serving response + context (Context): an object containing request and configuration details + Returns: + (bytes, string): data to return to client, response content type + """ + if data.status_code != 200: + raise ValueError(data.content.decode("utf-8")) + + response_content_type = context.accept_header + prediction = data.content + try: + prediction_dict = json.loads(prediction.decode("utf-8")) + if hasattr(schema_builder, "custom_output_translator"): + return ( + schema_builder.custom_output_translator.serialize( + prediction_dict["predictions"], response_content_type + ), + response_content_type, + ) + else: + return schema_builder.output_serializer.serialize(prediction), response_content_type + except Exception as e: + logger.error("Encountered error: %s in serialize_response." % e) + raise Exception("Encountered error in serialize_response.") from e + + +def _run_preflight_diagnostics(): + _py_vs_parity_check() + _pickle_file_integrity_check() + + +def _py_vs_parity_check(): + container_py_vs = platform.python_version() + local_py_vs = os.getenv("LOCAL_PYTHON") + + if not local_py_vs or container_py_vs.split(".")[1] != local_py_vs.split(".")[1]: + logger.warning( + f"The local python version {local_py_vs} differs from the python version " + f"{container_py_vs} on the container. Please align the two to avoid unexpected behavior" + ) + + +def _pickle_file_integrity_check(): + with open(SERVE_PATH, "rb") as f: + buffer = f.read() + + perform_integrity_check(buffer=buffer, metadata_path=METADATA_PATH) + + +def _set_up_schema_builder(): + """Sets up the schema_builder object.""" + global schema_builder + with open(SERVE_PATH, "rb") as f: + schema_builder = cloudpickle.load(f) + + +def _set_up_shared_libs(): + """Sets up the shared libs path.""" + if SHARED_LIBS_DIR.exists(): + # before importing, place dynamic linked libraries in shared lib path + shutil.copytree(SHARED_LIBS_DIR, "/lib", dirs_exist_ok=True) + + +def _convert_for_serialization(deserialized_data): + """Attempt to convert non-serializable objects to a serializable form. + + Args: + deserialized_data: The object to convert. + + Returns: + The converted object if it was not originally serializable, otherwise the original object. 
+ """ + import numpy as np + import pandas as pd + + if isinstance(deserialized_data, np.ndarray): + return deserialized_data.tolist() + elif isinstance(deserialized_data, pd.DataFrame): + return deserialized_data.to_dict(orient="list") + elif isinstance(deserialized_data, pd.Series): + return deserialized_data.tolist() + return deserialized_data + + +# on import, execute +_run_preflight_diagnostics() +_set_up_schema_builder() +_set_up_shared_libs() diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py b/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py new file mode 100644 index 0000000000..e9aa4aafff --- /dev/null +++ b/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py @@ -0,0 +1,67 @@ +"""Module for artifacts preparation for tensorflow_serving""" + +from __future__ import absolute_import +from pathlib import Path +import shutil +from typing import List, Dict, Any + +from sagemaker.serve.model_format.mlflow.utils import ( + _get_saved_model_path_for_tensorflow_and_keras_flavor, + _move_contents, +) +from sagemaker.serve.detector.dependency_manager import capture_dependencies +from sagemaker.serve.validations.check_integrity import ( + generate_secret_key, + compute_hash, +) +from sagemaker.remote_function.core.serialization import _MetaData + + +def prepare_for_tf_serving( + model_path: str, + shared_libs: List[str], + dependencies: Dict[str, Any], +) -> str: + """Prepares the model for serving. + + Args: + model_path (str): Path to the model directory. + shared_libs (List[str]): List of shared libraries. + dependencies (Dict[str, Any]): Dictionary of dependencies. + + Returns: + str: Secret key. + """ + + _model_path = Path(model_path) + if not _model_path.exists(): + _model_path.mkdir() + elif not _model_path.is_dir(): + raise Exception("model_dir is not a valid directory") + + code_dir = _model_path.joinpath("code") + code_dir.mkdir(exist_ok=True) + shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir) + + shared_libs_dir = _model_path.joinpath("shared_libs") + shared_libs_dir.mkdir(exist_ok=True) + for shared_lib in shared_libs: + shutil.copy2(Path(shared_lib), shared_libs_dir) + + capture_dependencies(dependencies=dependencies, work_dir=code_dir) + + saved_model_bundle_dir = _model_path.joinpath("1") + saved_model_bundle_dir.mkdir(exist_ok=True) + mlflow_saved_model_dir = _get_saved_model_path_for_tensorflow_and_keras_flavor(model_path) + if not mlflow_saved_model_dir: + raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.") + _move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir) + + secret_key = generate_secret_key() + with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: + buffer = f.read() + hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: + metadata.write(_MetaData(hash_value).to_json()) + + return secret_key diff --git a/src/sagemaker/serve/model_server/tensorflow_serving/server.py b/src/sagemaker/serve/model_server/tensorflow_serving/server.py new file mode 100644 index 0000000000..45931e9afc --- /dev/null +++ b/src/sagemaker/serve/model_server/tensorflow_serving/server.py @@ -0,0 +1,148 @@ +"""Module for Local Tensorflow Server""" + +from __future__ import absolute_import + +import requests +import logging +import platform +from pathlib import Path +from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.utils.optimize_utils import _is_s3_uri +from 
sagemaker.session import Session +from sagemaker.serve.utils.exceptions import LocalModelInvocationException +from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url +from sagemaker import fw_utils +from sagemaker.serve.utils.uploader import upload +from sagemaker.local.utils import get_docker_host + +logger = logging.getLogger(__name__) + + +class LocalTensorflowServing: + """LocalTensorflowServing class.""" + + def _start_tensorflow_serving( + self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict + ): + """Starts a local tensorflow serving container. + + Args: + client: Docker client + image: Image to use + model_path: Path to the model + secret_key: Secret key to use for authentication + env_vars: Environment variables to set + """ + self.container = client.containers.run( + image, + "serve", + detach=True, + auto_remove=True, + network_mode="host", + volumes={ + Path(model_path): { + "bind": "/opt/ml/model", + "mode": "rw", + }, + }, + environment={ + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_SERVE_SECRET_KEY": secret_key, + "LOCAL_PYTHON": platform.python_version(), + **env_vars, + }, + ) + + def _invoke_tensorflow_serving(self, request: object, content_type: str, accept: str): + """Invokes a local tensorflow serving container. + + Args: + request: Request to send + content_type: Content type to use + accept: Accept to use + """ + try: + response = requests.post( + f"http://{get_docker_host()}:8080/invocations", + data=request, + headers={"Content-Type": content_type, "Accept": accept}, + timeout=60, # this is what SageMaker Hosting uses as timeout + ) + response.raise_for_status() + return response.content + except Exception as e: + raise Exception("Unable to send request to the local container server") from e + + def _tensorflow_serving_deep_ping(self, predictor: PredictorBase): + """Checks if the local tensorflow serving container is up and running. + + If the container is not up and running, it will raise an exception. + """ + response = None + try: + response = predictor.predict(self.schema_builder.sample_input) + return (True, response) + # pylint: disable=broad-except + except Exception as e: + if "422 Client Error: Unprocessable Entity for url" in str(e): + raise LocalModelInvocationException(str(e)) + return (False, response) + + return (True, response) + + +class SageMakerTensorflowServing: + """SageMakerTensorflowServing class.""" + + def _upload_tensorflow_serving_artifacts( + self, + model_path: str, + sagemaker_session: Session, + secret_key: str, + s3_model_data_url: str = None, + image: str = None, + should_upload_artifacts: bool = False, + ): + """Uploads the model artifacts to S3. 
+ + Args: + model_path: Path to the model + sagemaker_session: SageMaker session + secret_key: Secret key to use for authentication + s3_model_data_url: S3 model data URL + image: Image to use + should_upload_artifacts: Whether to upload artifacts + """ + s3_upload_path = None + if _is_s3_uri(model_path): + s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) + + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) + + env_vars = { + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_REGION": sagemaker_session.boto_region_name, + "SAGEMAKER_CONTAINER_LOG_LEVEL": "10", + "SAGEMAKER_SERVE_SECRET_KEY": secret_key, + "LOCAL_PYTHON": platform.python_version(), + } + return s3_upload_path, env_vars diff --git a/src/sagemaker/serve/model_server/tgi/server.py b/src/sagemaker/serve/model_server/tgi/server.py index ef39e890c8..8ccc8e7ddc 100644 --- a/src/sagemaker/serve/model_server/tgi/server.py +++ b/src/sagemaker/serve/model_server/tgi/server.py @@ -12,11 +12,12 @@ from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join from sagemaker.s3 import S3Uploader from sagemaker.local.utils import get_docker_host +from sagemaker.serve.utils.optimize_utils import _is_s3_uri MODE_DIR_BINDING = "/opt/ml/model/" _SHM_SIZE = "2G" _DEFAULT_ENV_VARS = { - "TRANSFORMERS_CACHE": "/opt/ml/model/", + "HF_HOME": "/opt/ml/model/", + "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", } @@ -110,38 +111,47 @@ def _upload_tgi_artifacts( s3_model_data_url: str = None, image: str = None, env_vars: dict = None, + should_upload_artifacts: bool = False, ): - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + model_data_url = None + if _is_s3_uri(model_path): + model_data_url = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - code_dir = Path(model_path).joinpath("code") + code_dir = Path(model_path).joinpath("code") - s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") + s3_location = s3_path_join("s3://", bucket, code_key_prefix, "code") - logger.debug("Uploading TGI Model Resources uncompressed to: %s", s3_location) + logger.debug("Uploading TGI Model Resources uncompressed to: %s", s3_location) - model_data_url = S3Uploader.upload( - str(code_dir), - s3_location, - None, - sagemaker_session, - ) + model_data_url = S3Uploader.upload( + str(code_dir), + s3_location, + None, + 
sagemaker_session, + ) - model_data = { - "S3DataSource": { - "CompressionType": "None", - "S3DataType": "S3Prefix", - "S3Uri": model_data_url + "/", + model_data = ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": model_data_url + "/", + } } - } + if model_data_url + else None + ) if jumpstart: return (model_data, {}) return (model_data, _update_env_vars(env_vars)) diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py index 2675f6ea6a..058103a1fd 100644 --- a/src/sagemaker/serve/model_server/torchserve/inference.py +++ b/src/sagemaker/serve/model_server/torchserve/inference.py @@ -66,13 +66,39 @@ def input_fn(input_data, content_type): """Placeholder docstring""" try: if hasattr(schema_builder, "custom_input_translator"): - return schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data), content_type + deserialized_data = schema_builder.custom_input_translator.deserialize( + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type, ) else: - return schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data), content_type[0] + deserialized_data = schema_builder.input_deserializer.deserialize( + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type[0], ) + + # Check if preprocess method is defined and call it + if hasattr(inference_spec, "preprocess"): + return inference_spec.preprocess(deserialized_data) + + return deserialized_data except Exception as e: raise Exception("Encountered error in deserialize_request.") from e @@ -85,6 +111,8 @@ def predict_fn(input_data, predict_callable): def output_fn(predictions, accept_type): """Placeholder docstring""" try: + if hasattr(inference_spec, "postprocess"): + predictions = inference_spec.postprocess(predictions) if hasattr(schema_builder, "custom_output_translator"): return schema_builder.custom_output_translator.serialize(predictions, accept_type) else: diff --git a/src/sagemaker/serve/model_server/torchserve/server.py b/src/sagemaker/serve/model_server/torchserve/server.py index 5aef136355..74e37cd70b 100644 --- a/src/sagemaker/serve/model_server/torchserve/server.py +++ b/src/sagemaker/serve/model_server/torchserve/server.py @@ -7,6 +7,7 @@ import platform from pathlib import Path from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.utils.optimize_utils import _is_s3_uri from sagemaker.session import Session from sagemaker.serve.utils.exceptions import LocalModelInvocationException from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url @@ -84,24 +85,31 @@ def _upload_torchserve_artifacts( secret_key: str, s3_model_data_url: str = None, image: str = None, + should_upload_artifacts: bool = False, ): """Tar the model artifact and upload to S3 bucket, then prepare for the environment variables""" - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + s3_upload_path = None + if _is_s3_uri(model_path): + 
s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - logger.debug( - "Uploading the model resources to bucket=%s, key_prefix=%s.", bucket, code_key_prefix - ) - s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) - logger.debug("Model resources uploaded to: %s", s3_upload_path) + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + s3_upload_path = upload(sagemaker_session, model_path, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) env_vars = { "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", diff --git a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py index 4e82ec66b2..49cec5aab5 100644 --- a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py +++ b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py @@ -70,11 +70,31 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): return schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data), content_type + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type, ) else: return schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data), content_type[0] + ( + io.BytesIO(input_data.encode("utf-8")) + if not any( + [ + isinstance(input_data, bytes), + isinstance(input_data, bytearray), + ] + ) + else io.BytesIO(input_data) + ), + content_type[0], ) except Exception as e: raise Exception("Encountered error in deserialize_request.") from e diff --git a/src/sagemaker/serve/model_server/triton/server.py b/src/sagemaker/serve/model_server/triton/server.py index 62dfb4759a..e2f3c20d7a 100644 --- a/src/sagemaker/serve/model_server/triton/server.py +++ b/src/sagemaker/serve/model_server/triton/server.py @@ -9,6 +9,7 @@ from sagemaker import fw_utils from sagemaker import Session from sagemaker.base_predictor import PredictorBase +from sagemaker.serve.utils.optimize_utils import _is_s3_uri from sagemaker.serve.utils.uploader import upload from sagemaker.serve.utils.exceptions import LocalModelInvocationException from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url @@ -115,25 +116,32 @@ def _upload_triton_artifacts( secret_key: str, s3_model_data_url: str = None, image: str = None, + should_upload_artifacts: bool = False, ): """Tar triton artifacts and upload to s3""" - if s3_model_data_url: - bucket, key_prefix = parse_s3_url(url=s3_model_data_url) - else: - bucket, key_prefix = None, None - - code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) - - bucket, code_key_prefix = determine_bucket_and_prefix( - bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session - ) + s3_upload_path = None + if _is_s3_uri(model_path): + s3_upload_path = model_path + elif should_upload_artifacts: + if s3_model_data_url: + bucket, key_prefix = parse_s3_url(url=s3_model_data_url) + else: + bucket, key_prefix = None, 
None + + code_key_prefix = fw_utils.model_code_key_prefix(key_prefix, None, image) + + bucket, code_key_prefix = determine_bucket_and_prefix( + bucket=bucket, key_prefix=code_key_prefix, sagemaker_session=sagemaker_session + ) - logger.debug( - "Uploading the model resources to bucket=%s, key_prefix=%s.", bucket, code_key_prefix - ) - model_repository = model_path + "/model_repository" - s3_upload_path = upload(sagemaker_session, model_repository, bucket, code_key_prefix) - logger.debug("Model resources uploaded to: %s", s3_upload_path) + logger.debug( + "Uploading the model resources to bucket=%s, key_prefix=%s.", + bucket, + code_key_prefix, + ) + model_repository = model_path + "/model_repository" + s3_upload_path = upload(sagemaker_session, model_repository, bucket, code_key_prefix) + logger.debug("Model resources uploaded to: %s", s3_upload_path) env_vars = { "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model", diff --git a/src/sagemaker/serve/model_server/triton/triton_builder.py b/src/sagemaker/serve/model_server/triton/triton_builder.py index ed0ec49204..c47991fa09 100644 --- a/src/sagemaker/serve/model_server/triton/triton_builder.py +++ b/src/sagemaker/serve/model_server/triton/triton_builder.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) -SUPPORTED_TRITON_MODE = {Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT} +SUPPORTED_TRITON_MODE = {Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS} SUPPORTED_TRITON_FRAMEWORK = {"pytorch", "tensorflow"} INPUT_NAME = "input_1" OUTPUT_NAME = "output_1" @@ -428,6 +428,10 @@ def _create_triton_model(self) -> Type[Model]: self.pysdk_model.mode = self.mode self.pysdk_model.modes = self.modes self.pysdk_model.serve_settings = self.serve_settings + if hasattr(self, "role_arn") and self.role_arn: + self.pysdk_model.role = self.role_arn + if hasattr(self, "sagemaker_session") and self.sagemaker_session: + self.pysdk_model.sagemaker_session = self.sagemaker_session # dynamically generate a method to direct model.deploy() logic based on mode # unique method to models created via ModelBuilder() diff --git a/src/sagemaker/serve/spec/inference_base.py b/src/sagemaker/serve/spec/inference_base.py new file mode 100644 index 0000000000..23ea6cb01d --- /dev/null +++ b/src/sagemaker/serve/spec/inference_base.py @@ -0,0 +1,45 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
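Every _upload_*_artifacts helper in this change follows the same control flow: reuse model_path untouched when it is already an S3 URI, upload only when should_upload_artifacts is set, and otherwise return no model data. A condensed sketch of that pattern (the predicate below is a simplified stand-in for optimize_utils._is_s3_uri, and the bucket/prefix URI is hypothetical):

def sketch_is_s3_uri(path) -> bool:
    # Simplified stand-in for _is_s3_uri; the real check may be stricter.
    return isinstance(path, str) and path.startswith("s3://")

def sketch_resolve_upload(model_path: str, should_upload_artifacts: bool):
    if sketch_is_s3_uri(model_path):
        return model_path  # already staged in S3, reuse as-is
    if should_upload_artifacts:
        return "s3://example-bucket/example-prefix/code"  # placeholder upload target
    return None  # no upload requested, so no model data is returned

assert sketch_resolve_upload("s3://bucket/model", False) == "s3://bucket/model"
assert sketch_resolve_upload("/opt/ml/model", False) is None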
+"""Holds templated classes to enable users to provide custom inference scripting capabilities""" +from __future__ import absolute_import +from abc import ABC, abstractmethod + + +class CustomOrchestrator(ABC): + """Templated class to standardize sync entrypoint-based inference scripts""" + + def __init__(self): + self._client = None + + @property + def client(self): + """Boto3 SageMaker runtime client to use with custom orchestrator""" + if not hasattr(self, "_client") or not self._client: + from boto3 import Session + + self._client = Session().client("sagemaker-runtime") + return self._client + + @abstractmethod + def handle(self, data, context=None): + """Abstract class for defining an entrypoint for the model server""" + return NotImplemented + + +class AsyncCustomOrchestrator(ABC): + """Templated class to standardize async entrypoint-based inference scripts""" + + @abstractmethod + async def handle(self, data, context=None): + """Abstract class for defining an aynchronous entrypoint for the model server""" + return NotImplemented diff --git a/src/sagemaker/serve/spec/inference_spec.py b/src/sagemaker/serve/spec/inference_spec.py index b61d7d55ea..0397e84975 100644 --- a/src/sagemaker/serve/spec/inference_spec.py +++ b/src/sagemaker/serve/spec/inference_spec.py @@ -28,5 +28,14 @@ def invoke(self, input_object: object, model: object): model (object): The model object """ + def preprocess(self, input_data: object): + """Custom pre-processing function""" + + def postprocess(self, predictions: object): + """Custom post-processing function""" + def prepare(self, *args, **kwargs): """Custom prepare function""" + + def get_model(self): + """Return HuggingFace model name for inference spec""" diff --git a/src/sagemaker/serve/utils/conda_in_process.yml b/src/sagemaker/serve/utils/conda_in_process.yml new file mode 100644 index 0000000000..d51754ec5a --- /dev/null +++ b/src/sagemaker/serve/utils/conda_in_process.yml @@ -0,0 +1,113 @@ +name: conda_env +channels: + - defaults +dependencies: + - accelerate>=0.24.1,<=0.27.0 + - sagemaker_schema_inference_artifacts>=0.0.5 + - uvicorn>=0.30.1 + - fastapi>=0.111.0 + - nest-asyncio + - pip>=23.0.1 + - attrs>=24,<26 + - boto3>=1.34.142,<2.0 + - cloudpickle==2.2.1 + - google-pasta + - numpy==1.26.4 + - protobuf>=3.12,<5.0 + - smdebug_rulesconfig==1.0.1 + - importlib-metadata>=1.4.0,<7.0 + - packaging>=23.0,<25 + - pandas + - pathos + - schema + - PyYAML>=6.0.1 + - jsonschema + - platformdirs + - tblib>=1.7.0,<4 + - urllib3>=1.26.8,<3.0.0 + - requests + - docker + - tqdm + - psutil + - pip: + - altair>=4.2.2 + - anyio>=3.6.2 + - awscli>=1.27.114 + - blinker>=1.6.2 + - botocore>=1.29.114 + - cachetools>=5.3.0 + - certifi==2022.12.7 + - charset-normalizer>=3.1.0 + - click>=8.1.3 + - cloudpickle==2.2.1 + - colorama>=0.4.4 + - contextlib2>=21.6.0 + - decorator>=5.1.1 + - dill>=0.3.9 + - docutils>=0.16 + - entrypoints>=0.4 + - filelock>=3.11.0 + - gitdb>=4.0.10 + - gitpython>=3.1.31 + - gunicorn>=20.1.0 + - h11>=0.14.0 + - huggingface-hub>=0.13.4 + - idna>=3.4 + - importlib-metadata>=4.13.0 + - jinja2>=3.1.2 + - jmespath>=1.0.1 + - jsonschema>=4.17.3 + - markdown-it-py>=2.2.0 + - markupsafe>=2.1.2 + - mdurl>=0.1.2 + - mpmath>=1.3.0 + - multiprocess>=0.70.14 + - networkx>=3.1 + - packaging>=23.1 + - pandas>=1.5.3 + - pathos>=0.3.0 + - pillow>=9.5.0 + - platformdirs>=3.2.0 + - pox>=0.3.2 + - ppft>=1.7.6.6 + - protobuf>=3.20.3 + - protobuf3-to-dict>=0.1.5 + - pyarrow>=11.0.0 + - pyasn1>=0.4.8 + - pydantic>=1.10.7 + - pydeck>=0.8.1b0 + - pygments>=2.15.1 + - 
pympler>=1.0.1 + - pyrsistent>=0.19.3 + - python-dateutil>=2.8.2 + - pytz>=2023.3 + - pytz-deprecation-shim>=0.1.0.post0 + - pyyaml>=6.0.1 + - regex>=2023.3.23 + - requests>=2.28.2 + - rich>=13.3.4 + - rsa>=4.7.2 + - s3transfer>=0.6.0 + - sagemaker>=2.148.0 + - schema>=0.7.5 + - six>=1.16.0 + - smdebug-rulesconfig>=1.0.1 + - smmap==5.0.0 + - sniffio>=1.3.0 + - starlette>=0.26.1 + - streamlit>=1.21.0 + - sympy>=1.11.1 + - tblib>=1.7.0 + - tokenizers>=0.13.3 + - toml>=0.10.2 + - toolz>=0.12.0 + - torch>=2.0.0 + - tornado>=6.3 + - tqdm>=4.65.0 + - transformers>=4.28.1 + - typing-extensions>=4.5.0 + - tzdata>=2023.3 + - tzlocal>=4.3 + - urllib3>=1.26.15 + - validators>=0.20.0 + - zipp>=3.15.0 diff --git a/src/sagemaker/serve/utils/exceptions.py b/src/sagemaker/serve/utils/exceptions.py index 72b9083072..eb22e8cce2 100644 --- a/src/sagemaker/serve/utils/exceptions.py +++ b/src/sagemaker/serve/utils/exceptions.py @@ -1,4 +1,4 @@ -"""Placeholder Docstring""" +"""Exceptions used across different model builder invocations""" from __future__ import absolute_import @@ -24,6 +24,16 @@ def __init__(self, message): super().__init__(message=message) +class InProcessDeepPingException(ModelBuilderException): + """Raise when in process model serving does not pass the deep ping check""" + + fmt = "Error Message: {message}" + model_builder_error_code = 1 + + def __init__(self, message): + super().__init__(message=message) + + class LocalModelOutOfMemoryException(ModelBuilderException): """Raise when local model serving fails to load the model""" diff --git a/src/sagemaker/serve/utils/hf_utils.py b/src/sagemaker/serve/utils/hf_utils.py new file mode 100644 index 0000000000..75f46eeeb9 --- /dev/null +++ b/src/sagemaker/serve/utils/hf_utils.py @@ -0,0 +1,53 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Utility functions for fetching model information from HuggingFace Hub""" +from __future__ import absolute_import +import json +import urllib.request +from json import JSONDecodeError +from urllib.error import HTTPError, URLError +import logging + +logger = logging.getLogger(__name__) + + +def _get_model_config_properties_from_hf(model_id: str, hf_hub_token: str = None): + """Placeholder docstring""" + + config_url = f"https://huggingface.co/{model_id}/raw/main/config.json" + model_config = None + try: + if hf_hub_token: + config_url = urllib.request.Request( + config_url, headers={"Authorization": "Bearer " + hf_hub_token} + ) + with urllib.request.urlopen(config_url) as response: + model_config = json.load(response) + except (HTTPError, URLError, TimeoutError, JSONDecodeError) as e: + if "HTTP Error 401: Unauthorized" in str(e): + raise ValueError( + "Trying to access a gated/private HuggingFace model without valid credentials. " + "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars" + ) + logger.warning( + "Exception encountered while trying to read config file %s. 
" "Details: %s", + config_url, + e, + ) + if not model_config: + raise ValueError( + f"Did not find a config.json or model_index.json file in huggingface hub for " + f"{model_id}. Please make sure a config.json exists (or model_index.json for Stable " + f"Diffusion Models) for this model in the huggingface hub" + ) + return model_config diff --git a/src/sagemaker/serve/utils/in_process_requirements.txt b/src/sagemaker/serve/utils/in_process_requirements.txt new file mode 100644 index 0000000000..da1fd8e617 --- /dev/null +++ b/src/sagemaker/serve/utils/in_process_requirements.txt @@ -0,0 +1,85 @@ +altair>=4.2.2 +anyio>=3.6.2 +awscli>=1.27.114 +blinker>=1.6.2 +botocore>=1.29.114 +cachetools>=5.3.0 +certifi==2024.7.4 +charset-normalizer>=3.1.0 +click>=8.1.3 +cloudpickle==2.2.1 +colorama>=0.4.4 +contextlib2>=21.6.0 +decorator>=5.1.1 +dill>=0.3.9 +docutils>=0.16 +entrypoints>=0.4 +filelock>=3.11.0 +gitdb>=4.0.10 +gitpython>=3.1.31 +gunicorn>=20.1.0 +h11>=0.14.0 +huggingface-hub>=0.13.4 +idna>=3.4 +importlib-metadata>=4.13.0 +jinja2>=3.1.2 +jmespath>=1.0.1 +jsonschema>=4.17.3 +markdown-it-py>=2.2.0 +markupsafe>=2.1.2 +mdurl>=0.1.2 +mpmath>=1.3.0 +multiprocess>=0.70.14 +networkx>=3.1 +packaging>=23.1 +pandas>=1.5.3 +pathos>=0.3.0 +pillow>=9.5.0 +platformdirs>=3.2.0 +pox>=0.3.2 +ppft>=1.7.6.6 +protobuf>=3.20.3 +protobuf3-to-dict>=0.1.5 +pyarrow>=11.0.0 +pyasn1>=0.4.8 +pydantic>=1.10.7 +pydeck>=0.8.1b0 +pygments>=2.15.1 +pympler>=1.0.1 +pyrsistent>=0.19.3 +python-dateutil>=2.8.2 +pytz>=2023.3 +pytz-deprecation-shim>=0.1.0.post0 +pyyaml>=6.0.1 +regex>=2023.3.23 +requests>=2.28.2 +rich>=13.3.4 +rsa>=4.7.2 +s3transfer>=0.6.0 +sagemaker>=2.148.0 +schema>=0.7.5 +six>=1.16.0 +smdebug-rulesconfig>=1.0.1 +smmap==5.0.0 +sniffio>=1.3.0 +starlette>=0.26.1 +streamlit>=1.21.0 +sympy>=1.11.1 +tblib>=1.7.0 +tokenizers>=0.13.3 +toml>=0.10.2 +toolz>=0.12.0 +torch>=2.0.0 +tornado>=6.3 +tqdm>=4.65.0 +transformers>=4.28.1 +typing-extensions>=4.5.0 +tzdata>=2023.3 +tzlocal>=4.3 +urllib3>=1.26.15 +validators>=0.20.0 +zipp>=3.15.0 +uvicorn>=0.30.1 +fastapi>=0.111.0 +nest-asyncio +transformers diff --git a/src/sagemaker/serve/utils/lineage_constants.py b/src/sagemaker/serve/utils/lineage_constants.py new file mode 100644 index 0000000000..dce4a41139 --- /dev/null +++ b/src/sagemaker/serve/utils/lineage_constants.py @@ -0,0 +1,30 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Holds constants used for lineage support""" +from __future__ import absolute_import + + +LINEAGE_POLLER_INTERVAL_SECS = 15 +LINEAGE_POLLER_MAX_TIMEOUT_SECS = 120 +TRACKING_SERVER_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):mlflow-tracking-server/(.*?)$" +TRACKING_SERVER_CREATION_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" +MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE = "ModelBuilderInputModelData" +MLFLOW_S3_PATH = "S3" +MLFLOW_MODEL_PACKAGE_PATH = "ModelPackage" +MLFLOW_RUN_ID = "MLflowRunId" +MLFLOW_LOCAL_PATH = "Local" +MLFLOW_REGISTRY_PATH = "MLflowRegistry" +ERROR = "Error" +CODE = "Code" +CONTRIBUTED_TO = "ContributedTo" +VALIDATION_EXCEPTION = "ValidationException" diff --git a/src/sagemaker/serve/utils/lineage_utils.py b/src/sagemaker/serve/utils/lineage_utils.py new file mode 100644 index 0000000000..7278dd8a3c --- /dev/null +++ b/src/sagemaker/serve/utils/lineage_utils.py @@ -0,0 +1,328 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Holds the util functions used for lineage tracking""" +from __future__ import absolute_import + +import os +import time +import re +import logging +from typing import List, Optional, Union + +from botocore.exceptions import ClientError + +from sagemaker import Session +from sagemaker.lineage._api_types import ArtifactSummary +from sagemaker.lineage.artifact import Artifact +from sagemaker.lineage.association import Association +from sagemaker.lineage.query import LineageSourceEnum +from sagemaker.serve.model_format.mlflow.constants import ( + MLFLOW_RUN_ID_REGEX, + MODEL_PACKAGE_ARN_REGEX, + S3_PATH_REGEX, + MLFLOW_REGISTRY_PATH_REGEX, +) +from sagemaker.serve.utils.lineage_constants import ( + LINEAGE_POLLER_MAX_TIMEOUT_SECS, + LINEAGE_POLLER_INTERVAL_SECS, + TRACKING_SERVER_ARN_REGEX, + TRACKING_SERVER_CREATION_TIME_FORMAT, + MLFLOW_S3_PATH, + MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, + MLFLOW_LOCAL_PATH, + MLFLOW_MODEL_PACKAGE_PATH, + MLFLOW_RUN_ID, + MLFLOW_REGISTRY_PATH, + CONTRIBUTED_TO, + ERROR, + CODE, + VALIDATION_EXCEPTION, +) + +logger = logging.getLogger(__name__) + + +def _load_artifact_by_source_uri( + source_uri: str, + sagemaker_session: Session, + source_types_to_match: Optional[List[str]] = None, + artifact_type: Optional[str] = None, +) -> Optional[ArtifactSummary]: + """Load lineage artifact by source uri + + Arguments: + source_uri (str): The s3 uri used for uploading transfomred model artifacts. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + source_types_to_match (Optional[List[str]]): A list of source type values to match against + the artifact's source types. If provided, the artifact's source types must match this + list. + artifact_type (Optional[str]): The type of the lineage artifact. + + Returns: + ArtifactSummary: The Artifact Summary for the provided S3 URI. 
+ """ + artifacts = Artifact.list(source_uri=source_uri, sagemaker_session=sagemaker_session) + for artifact_summary in artifacts: + if artifact_type is None or artifact_summary.artifact_type == artifact_type: + if source_types_to_match: + if artifact_summary.source.source_types is not None: + artifact_source_types = [ + source_type["Value"] for source_type in artifact_summary.source.source_types + ] + if set(artifact_source_types) == set(source_types_to_match): + return artifact_summary + else: + return None + else: + return artifact_summary + + return None + + +def _poll_lineage_artifact( + s3_uri: str, artifact_type: str, sagemaker_session: Session +) -> Optional[ArtifactSummary]: + """Polls lineage artifacts by s3 path. + + Arguments: + s3_uri (str): The S3 URI to check for artifacts. + artifact_type (str): The type of the lineage artifact. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Optional[ArtifactSummary]: The artifact summary if found, otherwise None. + """ + logger.info("Polling lineage artifact for model data in %s", s3_uri) + start_time = time.time() + while time.time() - start_time < LINEAGE_POLLER_MAX_TIMEOUT_SECS: + result = _load_artifact_by_source_uri( + s3_uri, sagemaker_session, artifact_type=artifact_type + ) + if result is not None: + return result + time.sleep(LINEAGE_POLLER_INTERVAL_SECS) + + +def _get_mlflow_model_path_type(mlflow_model_path: str) -> str: + """Identify mlflow model path type. + + Args: + mlflow_model_path (str): The string to be identified. + + Returns: + str: Description of what the input string is identified as. + """ + mlflow_run_id_pattern = MLFLOW_RUN_ID_REGEX + mlflow_registry_id_pattern = MLFLOW_REGISTRY_PATH_REGEX + sagemaker_arn_pattern = MODEL_PACKAGE_ARN_REGEX + s3_pattern = S3_PATH_REGEX + + if re.match(mlflow_run_id_pattern, mlflow_model_path): + return MLFLOW_RUN_ID + if re.match(mlflow_registry_id_pattern, mlflow_model_path): + return MLFLOW_REGISTRY_PATH + if re.match(sagemaker_arn_pattern, mlflow_model_path): + return MLFLOW_MODEL_PACKAGE_PATH + if re.match(s3_pattern, mlflow_model_path): + return MLFLOW_S3_PATH + if os.path.exists(mlflow_model_path): + return MLFLOW_LOCAL_PATH + + raise ValueError(f"Invalid MLflow model path: {mlflow_model_path}") + + +def _create_mlflow_model_path_lineage_artifact( + mlflow_model_path: str, + sagemaker_session: Session, + source_types_to_match: Optional[List[str]] = None, +) -> Optional[Artifact]: + """Creates a lineage artifact for the given MLflow model path. + + Args: + mlflow_model_path (str): The path to the MLflow model. + sagemaker_session (Session): The SageMaker session object. + source_types_to_match (Optional[List[str]]): Artifact source types. + + Returns: + Optional[Artifact]: The created lineage artifact, or None if an error occurred. 
+ """ + _artifact_name = _get_mlflow_model_path_type(mlflow_model_path) + properties = dict( + model_builder_input_model_data_type=_artifact_name, + ) + try: + source_types = [dict(SourceIdType="Custom", Value="ModelBuilderInputModelData")] + if source_types_to_match: + source_types += [ + dict(SourceIdType="Custom", Value=source_type) + for source_type in source_types_to_match + if source_type != "ModelBuilderInputModelData" + ] + + return Artifact.create( + source_uri=mlflow_model_path, + source_types=source_types, + artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, + artifact_name=_artifact_name, + properties=properties, + sagemaker_session=sagemaker_session, + ) + except ClientError as e: + if e.response[ERROR][CODE] == VALIDATION_EXCEPTION: + logger.info("Artifact already exists") + else: + logger.warning("Failed to create mlflow model path lineage artifact: %s", e) + raise e + + +def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( + mlflow_model_path: str, + sagemaker_session: Session, + tracking_server_arn: Optional[str] = None, +) -> Optional[Union[Artifact, ArtifactSummary]]: + """Retrieves an existing artifact for the given MLflow model path or + + creates a new one if it doesn't exist. + + Args: + mlflow_model_path (str): The path to the MLflow model. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + tracking_server_arn (Optional[str]): The MLflow tracking server ARN. + + Returns: + Optional[Union[Artifact, ArtifactSummary]]: The existing or newly created artifact, + or None if an error occurred. + """ + source_types_to_match = ["ModelBuilderInputModelData"] + input_type = _get_mlflow_model_path_type(mlflow_model_path) + if tracking_server_arn and input_type in [MLFLOW_RUN_ID, MLFLOW_REGISTRY_PATH]: + match = re.match(TRACKING_SERVER_ARN_REGEX, tracking_server_arn) + mlflow_tracking_server_name = match.group(4) + describe_result = sagemaker_session.sagemaker_client.describe_mlflow_tracking_server( + TrackingServerName=mlflow_tracking_server_name + ) + tracking_server_creation_time = describe_result["CreationTime"].strftime( + TRACKING_SERVER_CREATION_TIME_FORMAT + ) + source_types_to_match += [tracking_server_arn, tracking_server_creation_time] + _loaded_artifact = _load_artifact_by_source_uri( + mlflow_model_path, + sagemaker_session, + source_types_to_match, + ) + if _loaded_artifact is not None: + return _loaded_artifact + return _create_mlflow_model_path_lineage_artifact( + mlflow_model_path, + sagemaker_session, + source_types_to_match, + ) + + +def _add_association_between_artifacts( + mlflow_model_path_artifact_arn: str, + autogenerated_model_data_artifact_arn: str, + sagemaker_session: Session, +) -> None: + """Add association between mlflow model path artifact and autogenerated model data artifact. + + Arguments: + mlflow_model_path_artifact_arn (str): The mlflow model path artifact. + autogenerated_model_data_artifact_arn (str): The autogenerated model data artifact. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
+ """ + _association_type = CONTRIBUTED_TO + _source_arn = mlflow_model_path_artifact_arn + _destination_arn = autogenerated_model_data_artifact_arn + try: + logger.info( + "Adding association with source_arn: " + "%s, destination_arn: %s and association_type: %s.", + _source_arn, + _destination_arn, + _association_type, + ) + Association.create( + source_arn=_source_arn, + destination_arn=_destination_arn, + association_type=_association_type, + sagemaker_session=sagemaker_session, + ) + except ClientError as e: + if e.response[ERROR][CODE] == VALIDATION_EXCEPTION: + logger.info("Association already exists") + else: + raise e + + +def _maintain_lineage_tracking_for_mlflow_model( + mlflow_model_path: str, + s3_upload_path: str, + sagemaker_session: Session, + tracking_server_arn: Optional[str] = None, +) -> None: + """Maintains lineage tracking for an MLflow model by creating or retrieving artifacts. + + Args: + mlflow_model_path (str): The path to the MLflow model. + s3_upload_path (str): The S3 path where the transformed model data is uploaded. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + tracking_server_arn (Optional[str]): The MLflow tracking server ARN. + """ + artifact_for_transformed_model_data = _poll_lineage_artifact( + s3_uri=s3_upload_path, + artifact_type=LineageSourceEnum.MODEL_DATA.value, + sagemaker_session=sagemaker_session, + ) + if artifact_for_transformed_model_data: + mlflow_model_artifact = ( + _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( + mlflow_model_path=mlflow_model_path, + sagemaker_session=sagemaker_session, + tracking_server_arn=tracking_server_arn, + ) + ) + if mlflow_model_artifact: + _mlflow_model_artifact_arn = ( + mlflow_model_artifact.artifact_arn + ) # pylint: disable=E1101, disable=C0301 + _artifact_for_transformed_model_data_arn = ( + artifact_for_transformed_model_data.artifact_arn + ) # pylint: disable=C0301 + _add_association_between_artifacts( + mlflow_model_path_artifact_arn=_mlflow_model_artifact_arn, + autogenerated_model_data_artifact_arn=_artifact_for_transformed_model_data_arn, + sagemaker_session=sagemaker_session, + ) + else: + logger.warning( + "Unable to add association between autogenerated lineage " + "artifact for transformed model data and mlflow model path" + " lineage artifacts." + ) + else: + logger.warning( + "Lineage artifact for transformed model data is not auto-created within " + "%s seconds, skipping creation of lineage artifact for mlflow model path", + LINEAGE_POLLER_MAX_TIMEOUT_SECS, + ) diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py new file mode 100644 index 0000000000..68ed1e846d --- /dev/null +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -0,0 +1,503 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Holds the util functions used for the optimize function""" +from __future__ import absolute_import + +import re +import logging +from typing import Dict, Any, Optional, Union, List, Tuple + +from sagemaker import Model, Session +from sagemaker.enums import Tag +from sagemaker.jumpstart.utils import accessors, get_eula_message + + +logger = logging.getLogger(__name__) + + +SPECULATIVE_DRAFT_MODEL = "/opt/ml/additional-model-data-sources" + + +def _is_inferentia_or_trainium(instance_type: Optional[str]) -> bool: + """Checks whether an instance is compatible with Inferentia. + + Args: + instance_type (str): The instance type used for the compilation job. + + Returns: + bool: Whether the given instance type is Inferentia or Trainium. + """ + if isinstance(instance_type, str): + match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + if match: + if match[1].startswith("inf") or match[1].startswith("trn"): + return True + return False + + +def _is_image_compatible_with_optimization_job(image_uri: Optional[str]) -> bool: + """Checks whether an instance is compatible with an optimization job. + + Args: + image_uri (str): The image URI of the optimization job. + + Returns: + bool: Whether the given instance type is compatible with an optimization job. + """ + # TODO: Use specific container type instead. + if image_uri is None: + return True + return "djl-inference:" in image_uri and ("-lmi" in image_uri or "-neuronx-" in image_uri) + + +def _deployment_config_contains_draft_model(deployment_config: Optional[Dict]) -> bool: + """Checks whether a deployment config contains a speculative decoding draft model. + + Args: + deployment_config (Dict): The deployment config to check. + + Returns: + bool: Whether the deployment config contains a draft model or not. + """ + if deployment_config is None: + return False + deployment_args = deployment_config.get("DeploymentArgs", {}) + additional_data_sources = deployment_args.get("AdditionalDataSources") + + return "speculative_decoding" in additional_data_sources if additional_data_sources else False + + +def _is_draft_model_jumpstart_provided(deployment_config: Optional[Dict]) -> bool: + """Checks whether a deployment config's draft model is provided by JumpStart. + + Args: + deployment_config (Dict): The deployment config to check. + + Returns: + bool: Whether the draft model is provided by JumpStart or not. + """ + if deployment_config is None: + return False + + additional_model_data_sources = deployment_config.get("DeploymentArgs", {}).get( + "AdditionalDataSources" + ) + for source in additional_model_data_sources.get("speculative_decoding", []): + if source["channel_name"] == "draft_model": + if source.get("provider", {}).get("name") == "JumpStart": + return True + continue + return False + + +def _generate_optimized_model(pysdk_model: Model, optimization_response: dict) -> Model: + """Generates a new optimization model. + + Args: + pysdk_model (Model): A PySDK model. + optimization_response (dict): The optimization response. + + Returns: + Model: A deployable optimized model. 
+ """ + recommended_image_uri = optimization_response.get("OptimizationOutput", {}).get( + "RecommendedInferenceImage" + ) + s3_uri = optimization_response.get("OutputConfig", {}).get("S3OutputLocation") + deployment_instance_type = optimization_response.get("DeploymentInstanceType") + + if recommended_image_uri: + pysdk_model.image_uri = recommended_image_uri + if s3_uri: + pysdk_model.model_data["S3DataSource"]["S3Uri"] = s3_uri + if deployment_instance_type: + pysdk_model.instance_type = deployment_instance_type + + pysdk_model.add_tags( + {"Key": Tag.OPTIMIZATION_JOB_NAME, "Value": optimization_response["OptimizationJobName"]} + ) + return pysdk_model + + +def _is_optimized(pysdk_model: Model) -> bool: + """Checks whether an optimization model is optimized. + + Args: + pysdk_model (Model): A PySDK model. + + Return: + bool: Whether the given model type is optimized. + """ + optimized_tags = [Tag.OPTIMIZATION_JOB_NAME, Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER] + if hasattr(pysdk_model, "_tags") and pysdk_model._tags: + if isinstance(pysdk_model._tags, dict): + return pysdk_model._tags.get("Key") in optimized_tags + for tag in pysdk_model._tags: + if tag.get("Key") in optimized_tags: + return True + return False + + +def _generate_model_source( + model_data: Optional[Union[Dict[str, Any], str]], accept_eula: Optional[bool] +) -> Optional[Dict[str, Any]]: + """Extracts model source from model data. + + Args: + model_data (Optional[Union[Dict[str, Any], str]]): A model data. + + Returns: + Optional[Dict[str, Any]]: Model source data. + """ + if model_data is None: + raise ValueError("Model Optimization Job only supports model with S3 data source.") + + s3_uri = model_data + if isinstance(s3_uri, dict): + s3_uri = s3_uri.get("S3DataSource").get("S3Uri") + + model_source = {"S3": {"S3Uri": s3_uri}} + if accept_eula: + model_source["S3"]["ModelAccessConfig"] = {"AcceptEula": True} + return model_source + + +def _update_environment_variables( + env: Optional[Dict[str, str]], new_env: Optional[Dict[str, str]] +) -> Optional[Dict[str, str]]: + """Updates environment variables based on environment variables. + + Args: + env (Optional[Dict[str, str]]): The environment variables. + new_env (Optional[Dict[str, str]]): The new environment variables. + + Returns: + Optional[Dict[str, str]]: The updated environment variables. + """ + if new_env: + if env: + env.update(new_env) + else: + env = new_env + return env + + +def _extract_speculative_draft_model_provider( + speculative_decoding_config: Optional[Dict] = None, +) -> Optional[str]: + """Extracts speculative draft model provider from speculative decoding config. + + Args: + speculative_decoding_config (Optional[Dict]): A speculative decoding config. + + Returns: + Optional[str]: The speculative draft model provider. + """ + if speculative_decoding_config is None: + return None + + model_provider = speculative_decoding_config.get("ModelProvider", "").lower() + + if model_provider == "jumpstart": + return "jumpstart" + + if model_provider == "custom" or speculative_decoding_config.get("ModelSource"): + return "custom" + + if model_provider == "sagemaker": + return "sagemaker" + + return "auto" + + +def _extract_additional_model_data_source_s3_uri( + additional_model_data_source: Optional[Dict] = None, +) -> Optional[str]: + """Extracts model data source s3 uri from a model data source in Pascal case. + + Args: + additional_model_data_source (Optional[Dict]): A model data source. + + Returns: + str: S3 uri of the model resources. 
+ """ + if ( + additional_model_data_source is None + or additional_model_data_source.get("S3DataSource", None) is None + ): + return None + + return additional_model_data_source.get("S3DataSource").get("S3Uri") + + +def _extract_deployment_config_additional_model_data_source_s3_uri( + additional_model_data_source: Optional[Dict] = None, +) -> Optional[str]: + """Extracts model data source s3 uri from a model data source in snake case. + + Args: + additional_model_data_source (Optional[Dict]): A model data source. + + Returns: + str: S3 uri of the model resources. + """ + if ( + additional_model_data_source is None + or additional_model_data_source.get("s3_data_source", None) is None + ): + return None + + return additional_model_data_source.get("s3_data_source").get("s3_uri", None) + + +def _is_draft_model_gated( + draft_model_config: Optional[Dict] = None, +) -> bool: + """Extracts model gated-ness from draft model data source. + + Args: + draft_model_config (Optional[Dict]): A model data source. + + Returns: + bool: Whether the draft model is gated or not. + """ + return "hosting_eula_key" in draft_model_config if draft_model_config else False + + +def _extracts_and_validates_speculative_model_source( + speculative_decoding_config: Dict, +) -> str: + """Extracts model source from speculative decoding config. + + Args: + speculative_decoding_config (Optional[Dict]): A speculative decoding config. + + Returns: + str: Model source. + + Raises: + ValueError: If model source is none. + """ + model_source: str = speculative_decoding_config.get("ModelSource") + + if not model_source: + raise ValueError("ModelSource must be provided in speculative decoding config.") + return model_source + + +def _generate_channel_name(additional_model_data_sources: Optional[List[Dict]]) -> str: + """Generates a channel name. + + Args: + additional_model_data_sources (Optional[List[Dict]]): The additional model data sources. + + Returns: + str: The channel name. + """ + channel_name = "draft_model" + if additional_model_data_sources and len(additional_model_data_sources) > 0: + channel_name = additional_model_data_sources[0].get("ChannelName", channel_name) + + return channel_name + + +def _generate_additional_model_data_sources( + model_source: str, + channel_name: str, + accept_eula: bool = False, + s3_data_type: Optional[str] = "S3Prefix", + compression_type: Optional[str] = "None", +) -> List[Dict]: + """Generates additional model data sources. + + Args: + model_source (Optional[str]): The model source. + channel_name (Optional[str]): The channel name. + accept_eula (Optional[bool]): Whether to accept eula or not. + s3_data_type (Optional[str]): The S3 data type, defaults to 'S3Prefix'. + compression_type (Optional[str]): The compression type, defaults to None. + + Returns: + List[Dict]: The additional model data sources. + """ + + additional_model_data_source = { + "ChannelName": channel_name, + "S3DataSource": { + "S3Uri": model_source, + "S3DataType": s3_data_type, + "CompressionType": compression_type, + }, + } + if accept_eula: + additional_model_data_source["S3DataSource"]["ModelAccessConfig"] = {"AcceptEula": True} + + return [additional_model_data_source] + + +def _is_s3_uri(s3_uri: Optional[str]) -> bool: + """Checks whether an S3 URI is valid. + + Args: + s3_uri (Optional[str]): The S3 URI. + + Returns: + bool: Whether the S3 URI is valid. 
+ """ + if s3_uri is None: + return False + + return re.match("^s3://([^/]+)/?(.*)$", s3_uri) is not None + + +def _extract_optimization_config_and_env( + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, +) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]: + """Extracts optimization config and environment variables. + + Args: + quantization_config (Optional[Dict]): The quantization config. + compilation_config (Optional[Dict]): The compilation config. + sharding_config (Optional[Dict]): The sharding config. + + Returns: + Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]: + The optimization config and environment variables. + """ + optimization_config = {} + quantization_override_env = ( + quantization_config.get("OverrideEnvironment") if quantization_config else None + ) + compilation_override_env = ( + compilation_config.get("OverrideEnvironment") if compilation_config else None + ) + sharding_override_env = sharding_config.get("OverrideEnvironment") if sharding_config else None + + if quantization_config is not None: + optimization_config["ModelQuantizationConfig"] = quantization_config + + if compilation_config is not None: + optimization_config["ModelCompilationConfig"] = compilation_config + + if sharding_config is not None: + optimization_config["ModelShardingConfig"] = sharding_config + + # Return optimization config dict and environment variables if either is present + if optimization_config: + return ( + optimization_config, + quantization_override_env, + compilation_override_env, + sharding_override_env, + ) + + return None, None, None, None + + +def _custom_speculative_decoding( + model: Model, + speculative_decoding_config: Optional[Dict], + accept_eula: Optional[bool] = False, +) -> Model: + """Modifies the given model for speculative decoding config with custom provider. + + Args: + model (Model): The model. + speculative_decoding_config (Optional[Dict]): The speculative decoding config. + accept_eula (Optional[bool]): Whether to accept eula or not. + """ + + if speculative_decoding_config: + additional_model_source = _extracts_and_validates_speculative_model_source( + speculative_decoding_config + ) + + accept_eula = speculative_decoding_config.get("AcceptEula", accept_eula) + + if _is_s3_uri(additional_model_source): + channel_name = _generate_channel_name(model.additional_model_data_sources) + speculative_draft_model = f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}" + + model.additional_model_data_sources = _generate_additional_model_data_sources( + additional_model_source, channel_name, accept_eula + ) + else: + speculative_draft_model = additional_model_source + + model.env.update({"OPTION_SPECULATIVE_DRAFT_MODEL": speculative_draft_model}) + model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"}, + ) + + return model + + +def _jumpstart_speculative_decoding( + model=Model, + speculative_decoding_config: Optional[Dict[str, Any]] = None, + sagemaker_session: Optional[Session] = None, +): + """Modifies the given model for speculative decoding config with JumpStart provider. + + Args: + model (Model): The model. + speculative_decoding_config (Optional[Dict]): The speculative decoding config. + sagemaker_session (Optional[Session]): Sagemaker session for execution. 
+ """ + if speculative_decoding_config: + js_id = speculative_decoding_config.get("ModelID") + if not js_id: + raise ValueError( + "`ModelID` is a required field in `speculative_decoding_config` when " + "using JumpStart as draft model provider." + ) + model_version = speculative_decoding_config.get("ModelVersion", "*") + accept_eula = speculative_decoding_config.get("AcceptEula", False) + channel_name = _generate_channel_name(model.additional_model_data_sources) + + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + model_id=js_id, + version=model_version, + region=sagemaker_session.boto_region_name, + sagemaker_session=sagemaker_session, + ) + model_spec_json = model_specs.to_json() + + js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket() + + if model_spec_json.get("gated_bucket", False): + if not accept_eula: + eula_message = get_eula_message( + model_specs=model_specs, region=sagemaker_session.boto_region_name + ) + raise ValueError( + f"{eula_message} Set `AcceptEula`=True in " + f"speculative_decoding_config once acknowledged." + ) + js_bucket = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket() + + key_prefix = model_spec_json.get("hosting_prepacked_artifact_key") + model.additional_model_data_sources = _generate_additional_model_data_sources( + f"s3://{js_bucket}/{key_prefix}", + channel_name, + accept_eula, + ) + + model.env.update( + {"OPTION_SPECULATIVE_DRAFT_MODEL": f"{SPECULATIVE_DRAFT_MODEL}/{channel_name}/"} + ) + model.add_tags( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"}, + ) diff --git a/src/sagemaker/serve/utils/predictors.py b/src/sagemaker/serve/utils/predictors.py index e0ff8f8ee1..af05de6425 100644 --- a/src/sagemaker/serve/utils/predictors.py +++ b/src/sagemaker/serve/utils/predictors.py @@ -3,9 +3,10 @@ from __future__ import absolute_import import io from typing import Type - +import logging from sagemaker import Session from sagemaker.serve.mode.local_container_mode import LocalContainerMode +from sagemaker.serve.mode.in_process_mode import InProcessMode from sagemaker.serve.builder.schema_builder import SchemaBuilder from sagemaker.serializers import IdentitySerializer, JSONSerializer from sagemaker.deserializers import BytesDeserializer, JSONDeserializer @@ -15,6 +16,8 @@ APPLICATION_X_NPY = "application/x-npy" +logger = logging.getLogger(__name__) + class TorchServeLocalPredictor(PredictorBase): """Lightweight predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes""" @@ -209,6 +212,90 @@ def delete_predictor(self): self._mode_obj.destroy_server() +class TeiLocalModePredictor(PredictorBase): + """Lightweight Tei predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes""" + + def __init__( + self, + mode_obj: Type[LocalContainerMode], + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ): + self._mode_obj = mode_obj + self.serializer = serializer + self.deserializer = deserializer + + def predict(self, data): + """Placeholder docstring""" + return [ + self.deserializer.deserialize( + io.BytesIO( + self._mode_obj._invoke_serving( + self.serializer.serialize(data), + self.content_type, + self.deserializer.ACCEPT[0], + ) + ), + self.content_type, + ) + ] + + @property + def content_type(self): + """The MIME type of the data sent to the inference endpoint.""" + return self.serializer.CONTENT_TYPE + + @property + def accept(self): + """The content type(s) that are expected from the inference endpoint.""" + return self.deserializer.ACCEPT 
+ + def delete_predictor(self): + """Shut down and remove the container that you created in LOCAL_CONTAINER mode""" + self._mode_obj.destroy_server() + + +class TensorflowServingLocalPredictor(PredictorBase): + """Lightweight predictor for local deployment in LOCAL_CONTAINER modes""" + + # TODO: change mode_obj to union of IN_PROCESS and LOCAL_CONTAINER objs + def __init__( + self, + mode_obj: Type[LocalContainerMode], + serializer=IdentitySerializer(), + deserializer=BytesDeserializer(), + ): + self._mode_obj = mode_obj + self.serializer = serializer + self.deserializer = deserializer + + def predict(self, data): + """Placeholder docstring""" + return self.deserializer.deserialize( + io.BytesIO( + self._mode_obj._invoke_tensorflow_serving( + self.serializer.serialize(data), + self.content_type, + self.accept[0], + ) + ) + ) + + @property + def content_type(self): + """The MIME type of the data sent to the inference endpoint.""" + return self.serializer.CONTENT_TYPE + + @property + def accept(self): + """The content type(s) that are expected from the inference endpoint.""" + return self.deserializer.ACCEPT + + def delete_predictor(self): + """Shut down and remove the container that you created in LOCAL_CONTAINER mode""" + self._mode_obj.destroy_server() + + def _get_local_mode_predictor( model_server: ModelServer, mode_obj: Type[LocalContainerMode], @@ -223,9 +310,66 @@ def _get_local_mode_predictor( if model_server == ModelServer.TRITON: return TritonLocalPredictor(mode_obj=mode_obj) + if model_server == ModelServer.TENSORFLOW_SERVING: + return TensorflowServingLocalPredictor( + mode_obj=mode_obj, serializer=serializer, deserializer=deserializer + ) + raise ValueError("%s model server is not supported yet!" % model_server) +class InProcessModePredictor(PredictorBase): + """Lightweight predictor for in process mode deployment""" + + def __init__( + self, + mode_obj: Type[InProcessMode], + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ): + self._mode_obj = mode_obj + self.serializer = serializer + self.deserializer = deserializer + + def predict(self, data): + """Placeholder docstring""" + return self.deserializer.deserialize( + io.BytesIO( + self._mode_obj._invoke_serving( + self.serializer.serialize(data), + self.content_type, + self.accept[0], + ) + ) + ) + + @property + def content_type(self): + """The MIME type of the data sent to the inference endpoint.""" + return self.serializer.CONTENT_TYPE + + @property + def accept(self): + """The content type(s) that are expected from the inference endpoint.""" + return self.deserializer.ACCEPT + + def delete_predictor(self): + """Shut down and remove the container that you created in IN_PROCESS mode""" + self._mode_obj.destroy_server() + + +def _get_in_process_mode_predictor( + # model_server: ModelServer, + mode_obj: Type[InProcessMode], + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), +) -> Type[PredictorBase]: + """Returns Predictor for IN_PROCESS mode""" + return InProcessModePredictor( + mode_obj=mode_obj, serializer=serializer, deserializer=deserializer + ) + + def retrieve_predictor( endpoint_name: str, schema_builder: SchemaBuilder, diff --git a/src/sagemaker/serve/utils/telemetry_logger.py b/src/sagemaker/serve/utils/telemetry_logger.py index 64cbce03e8..6e7db9043b 100644 --- a/src/sagemaker/serve/utils/telemetry_logger.py +++ b/src/sagemaker/serve/utils/telemetry_logger.py @@ -19,8 +19,22 @@ from sagemaker import Session, exceptions from sagemaker.serve.mode.function_pointers import Mode +from 
sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH, MLFLOW_TRACKING_ARN from sagemaker.serve.utils.exceptions import ModelBuilderException -from sagemaker.serve.utils.types import ModelServer, ImageUriOption +from sagemaker.serve.utils.lineage_constants import ( + MLFLOW_LOCAL_PATH, + MLFLOW_S3_PATH, + MLFLOW_MODEL_PACKAGE_PATH, + MLFLOW_RUN_ID, + MLFLOW_REGISTRY_PATH, +) +from sagemaker.serve.utils.lineage_utils import _get_mlflow_model_path_type +from sagemaker.serve.utils.types import ( + ModelServer, + ImageUriOption, + ModelHub, + SpeculativeDecodingDraftModelSource, +) from sagemaker.serve.validations.check_image_uri import is_1p_image_uri from sagemaker.user_agent import SDK_VERSION @@ -49,6 +63,26 @@ str(ModelServer.DJL_SERVING): 4, str(ModelServer.TRITON): 5, str(ModelServer.TGI): 6, + str(ModelServer.TEI): 7, + str(ModelServer.SMD): 8, +} + +MLFLOW_MODEL_PATH_CODE = { + MLFLOW_LOCAL_PATH: 1, + MLFLOW_S3_PATH: 2, + MLFLOW_MODEL_PACKAGE_PATH: 3, + MLFLOW_RUN_ID: 4, + MLFLOW_REGISTRY_PATH: 5, +} + +MODEL_HUB_TO_CODE = { + str(ModelHub.JUMPSTART): 1, + str(ModelHub.HUGGINGFACE): 2, +} + +SD_DRAFT_MODEL_SOURCE_TO_CODE = { + str(SpeculativeDecodingDraftModelSource.SAGEMAKER): 1, + str(SpeculativeDecodingDraftModelSource.CUSTOM): 2, } @@ -61,63 +95,99 @@ def wrapper(self, *args, **kwargs): logger.info(TELEMETRY_OPT_OUT_MESSAGING) response = None caught_ex = None - - image_uri_tail = self.image_uri.split("/")[1] - image_uri_option = _get_image_uri_option(self.image_uri, self._is_custom_image_uri) - extra = ( - f"{func_name}" - f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}" - f"&x-imageTag={image_uri_tail}" - f"&x-sdkVersion={SDK_VERSION}" - f"&x-defaultImageUsage={image_uri_option}" - ) - - if self.model_server == ModelServer.DJL_SERVING or self.model_server == ModelServer.TGI: - extra += f"&x-modelName={self.model}" - - if self.sagemaker_session and self.sagemaker_session.endpoint_arn: - extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}" + status = "1" + failure_reason = None + failure_type = None + extra = f"{func_name}" start_timer = perf_counter() try: response = func(self, *args, **kwargs) - stop_timer = perf_counter() - elapsed = stop_timer - start_timer - extra += f"&x-latency={round(elapsed, 2)}" - if not self.serve_settings.telemetry_opt_out: - _send_telemetry( - "1", - MODE_TO_CODE[str(self.mode)], - self.sagemaker_session, - None, - None, - extra, - ) except ( ModelBuilderException, exceptions.CapacityError, exceptions.UnexpectedStatusException, exceptions.AsyncInferenceError, ) as e: - stop_timer = perf_counter() - elapsed = stop_timer - start_timer - extra += f"&x-latency={round(elapsed, 2)}" - if not self.serve_settings.telemetry_opt_out: - _send_telemetry( - "0", - MODE_TO_CODE[str(self.mode)], - self.sagemaker_session, - str(e), - e.__class__.__name__, - extra, - ) + status = "0" caught_ex = e + failure_reason = str(e) + failure_type = e.__class__.__name__ except Exception as e: # pylint: disable=W0703 - caught_ex = e - finally: - if caught_ex: - raise caught_ex - return response # pylint: disable=W0150 + raise e + + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + + if self.model_server: + extra += f"&x-modelServer={MODEL_SERVER_TO_CODE[str(self.model_server)]}" + + if self.image_uri: + image_uri_option = _get_image_uri_option( + self.image_uri, getattr(self, "_is_custom_image_uri", False) + ) + split_image_uri = self.image_uri.split("/") + if len(split_image_uri) > 1: + extra += 
f"&x-imageTag={split_image_uri[1]}" + + extra += f"&x-sdkVersion={SDK_VERSION}" + + if self.image_uri: + extra += f"&x-defaultImageUsage={image_uri_option}" + + if self.model_server == ModelServer.DJL_SERVING or self.model_server == ModelServer.TGI: + extra += f"&x-modelName={self.model}" + + if self.sagemaker_session and self.sagemaker_session.endpoint_arn: + extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}" + + if getattr(self, "_is_mlflow_model", False): + mlflow_model_path = self.model_metadata[MLFLOW_MODEL_PATH] + mlflow_model_path_type = _get_mlflow_model_path_type(mlflow_model_path) + extra += f"&x-mlflowModelPathType={MLFLOW_MODEL_PATH_CODE[mlflow_model_path_type]}" + mlflow_model_tracking_server_arn = self.model_metadata.get(MLFLOW_TRACKING_ARN) + if mlflow_model_tracking_server_arn is not None: + extra += f"&x-mlflowTrackingServerArn={mlflow_model_tracking_server_arn}" + + if getattr(self, "model_hub", False): + extra += f"&x-modelHub={MODEL_HUB_TO_CODE[str(self.model_hub)]}" + + if getattr(self, "is_fine_tuned", False): + extra += "&x-fineTuned=1" + + if getattr(self, "is_compiled", False): + extra += "&x-compiled=1" + if getattr(self, "is_quantized", False): + extra += "&x-quantized=1" + if getattr(self, "speculative_decoding_draft_model_source", False): + model_provider_enum = ( + SpeculativeDecodingDraftModelSource.SAGEMAKER + if self.speculative_decoding_draft_model_source == "sagemaker" + else SpeculativeDecodingDraftModelSource.CUSTOM + ) + model_provider_value = SD_DRAFT_MODEL_SOURCE_TO_CODE[str(model_provider_enum)] + extra += f"&x-sdDraftModelSource={model_provider_value}" + + if getattr(self, "deployment_config_name", False): + config_name_code = self.deployment_config_name.lower() + extra += f"&x-configName={config_name_code}" + + extra += f"&x-latency={round(elapsed, 2)}" + + if hasattr(self, "serve_settings") and not self.serve_settings.telemetry_opt_out: + _send_telemetry( + status, + MODE_TO_CODE[str(self.mode)], + self.sagemaker_session, + failure_reason, + failure_type, + extra, + ) + + if caught_ex: + raise caught_ex + + return response return wrapper diff --git a/src/sagemaker/serve/utils/tuning.py b/src/sagemaker/serve/utils/tuning.py index 22f3c06d47..5a63cfe508 100644 --- a/src/sagemaker/serve/utils/tuning.py +++ b/src/sagemaker/serve/utils/tuning.py @@ -7,6 +7,7 @@ import collections from multiprocessing.pool import ThreadPool from math import ceil +from typing import Callable import pandas as pd from numpy import percentile, std from sagemaker.serve.model_server.djl_serving.utils import _tokens_from_chars, _tokens_from_words @@ -33,8 +34,8 @@ def _pretty_print_results(results: dict): for key, value in ordered.items(): avg_latencies.append(key) - tensor_parallel_degrees.append(value[0]["option.tensor_parallel_degree"]) - dtypes.append(value[0]["option.dtype"]) + tensor_parallel_degrees.append(value[0]["TENSOR_PARALLEL_DEGREE"]) + dtypes.append(value[0]["OPTION_DTYPE"]) p90s.append(value[1]) avg_tokens_per_seconds.append(value[2]) throughput_per_seconds.append(value[3]) @@ -152,7 +153,7 @@ def _tokens_per_second(generated_text: str, max_token_length: int, latency: floa return min(est_tokens, max_token_length) / latency -def _timed_invoke(predict: callable, sample_input: object) -> tuple: +def _timed_invoke(predict: Callable, sample_input: object) -> tuple: """Placeholder docstring""" start_timer = perf_counter() response = predict(sample_input) diff --git a/src/sagemaker/serve/utils/types.py b/src/sagemaker/serve/utils/types.py index 
661093f249..b405d85b21 100644 --- a/src/sagemaker/serve/utils/types.py +++ b/src/sagemaker/serve/utils/types.py @@ -18,18 +18,8 @@ def __str__(self): DJL_SERVING = 4 TRITON = 5 TGI = 6 - - -class _DjlEngine(Enum): - """An enum for Djl Engines""" - - def __str__(self): - """Placeholder docstring""" - return str(self.name) - - DEEPSPEED = 1 - FASTER_TRANSFORMER = 2 - HUGGINGFACE_ACCELERATE = 3 + TEI = 7 + SMD = 8 class HardwareType(Enum): @@ -56,3 +46,25 @@ def __str__(self) -> str: CUSTOM_IMAGE = 1 CUSTOM_1P_IMAGE = 2 DEFAULT_IMAGE = 3 + + +class ModelHub(Enum): + """Enum type for model hub source""" + + def __str__(self) -> str: + """Convert enum to string""" + return str(self.name) + + JUMPSTART = 1 + HUGGINGFACE = 2 + + +class SpeculativeDecodingDraftModelSource(Enum): + """Enum type for speculative decoding draft model source""" + + def __str__(self) -> str: + """Convert enum to string""" + return str(self.name) + + SAGEMAKER = 1 + CUSTOM = 2 diff --git a/src/sagemaker/serve/validations/optimization.py b/src/sagemaker/serve/validations/optimization.py new file mode 100644 index 0000000000..58ef167039 --- /dev/null +++ b/src/sagemaker/serve/validations/optimization.py @@ -0,0 +1,229 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Holds the validation logic used for the .optimize() function. 
INTERNAL only""" +from __future__ import absolute_import + +import textwrap +import logging +from typing import Any, Dict, Set, Optional +from enum import Enum +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class _OptimizationContainer(Enum): + """Optimization containers""" + + TRT = "TRT" + VLLM = "vLLM" + NEURON = "Neuron" + + +class _OptimizationCombination(BaseModel): + """Optimization ruleset data structure for comparing input to ruleset""" + + optimization_container: _OptimizationContainer = None + compilation: Set[Optional[bool]] + speculative_decoding: Set[Optional[bool]] + sharding: Set[Optional[bool]] + quantization_technique: Set[Optional[str]] + + def validate_against(self, optimization_combination, rule_set: _OptimizationContainer): + """Validator for optimization containers""" + + # check the validity of each individual field + if not optimization_combination.compilation.issubset(self.compilation): + raise ValueError("Compilation") + if not optimization_combination.quantization_technique.issubset( + self.quantization_technique + ): + copy_quantization_technique = optimization_combination.quantization_technique.copy() + raise ValueError(f"Quantization:{copy_quantization_technique.pop()}") + if not optimization_combination.speculative_decoding.issubset(self.speculative_decoding): + raise ValueError("Speculative Decoding") + if not optimization_combination.sharding.issubset(self.sharding): + raise ValueError("Sharding") + + # optimization technique combinations that need to be validated + if optimization_combination.compilation and optimization_combination.speculative_decoding: + is_compiled = optimization_combination.compilation.copy().pop() + is_speculative_decoding = optimization_combination.speculative_decoding.copy().pop() + if is_compiled and is_speculative_decoding: + raise ValueError("Compilation and Speculative Decoding together") + + if rule_set == _OptimizationContainer.TRT: + is_compiled = optimization_combination.compilation.copy().pop() + is_quantized = optimization_combination.quantization_technique.copy().pop() + if is_quantized and not is_compiled: + raise ValueError(f"Quantization:{is_quantized} must be provided with Compilation") + + +TRUTHY_SET = {None, True} +FALSY_SET = {None, False} +TRT_CONFIGURATION = { + "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"}, + "optimization_combination": _OptimizationCombination( + optimization_container=_OptimizationContainer.TRT, + compilation=TRUTHY_SET, + quantization_technique={None, "awq", "fp8", "smoothquant"}, + speculative_decoding=FALSY_SET, + sharding=FALSY_SET, + ), +} +VLLM_CONFIGURATION = { + "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"}, + "optimization_combination": _OptimizationCombination( + optimization_container=_OptimizationContainer.VLLM, + compilation=FALSY_SET, + quantization_technique={None, "awq", "fp8"}, + speculative_decoding=TRUTHY_SET, + sharding=TRUTHY_SET, + ), +} +NEURON_CONFIGURATION = { + "supported_instance_families": {"inf2", "trn1", "trn1n"}, + "optimization_combination": _OptimizationCombination( + optimization_container=_OptimizationContainer.NEURON, + compilation=TRUTHY_SET, + quantization_technique={None}, + speculative_decoding=FALSY_SET, + sharding=FALSY_SET, + ), +} + + +def _validate_optimization_configuration( + is_jumpstart: bool, + instance_type: str, + quantization_config: Dict[str, Any], + compilation_config: Dict[str, Any], + sharding_config: Dict[str, Any], + speculative_decoding_config: Dict[str, Any], +): + 
"""Validate .optimize() input off of standard ruleset""" + + instance_family = None + if instance_type: + split_instance_type = instance_type.split(".") + if len(split_instance_type) == 3: + instance_family = split_instance_type[1] + + if ( + instance_family not in TRT_CONFIGURATION["supported_instance_families"] + and instance_family not in VLLM_CONFIGURATION["supported_instance_families"] + and instance_family not in NEURON_CONFIGURATION["supported_instance_families"] + ): + invalid_instance_type_msg = ( + f"Optimizations that uses {instance_type} instance type are " + "not currently supported both on GPU and Neuron instances" + ) + raise ValueError(invalid_instance_type_msg) + + quantization_technique = None + if ( + quantization_config + and quantization_config.get("OverrideEnvironment") + and quantization_config.get("OverrideEnvironment").get("OPTION_QUANTIZE") + ): + quantization_technique = quantization_config.get("OverrideEnvironment").get( + "OPTION_QUANTIZE" + ) + + optimization_combination = _OptimizationCombination( + compilation={None if compilation_config is None else True}, + speculative_decoding={None if speculative_decoding_config is None else True}, + sharding={None if sharding_config is None else True}, + quantization_technique={quantization_technique}, + ) + + # Check the case where no optimization combination is provided + if ( + optimization_combination.compilation == {None} + and optimization_combination.quantization_technique == {None} + and optimization_combination.speculative_decoding == {None} + and optimization_combination.sharding == {None} + ): + # JumpStart has defaults for Inf/Trn instances + if is_jumpstart and instance_family in NEURON_CONFIGURATION["supported_instance_families"]: + return + raise ValueError( + ( + "Optimizations that provide no optimization configs " + "are currently not support on both GPU and Neuron instances." + ) + ) + + # Validate based off of instance type + if instance_family in NEURON_CONFIGURATION["supported_instance_families"]: + try: + ( + NEURON_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.NEURON + ) + ) + except ValueError as neuron_compare_error: + raise ValueError( + ( + f"Optimizations that use {neuron_compare_error} " + "are not supported on Neuron instances." + ) + ) + else: + if optimization_combination.compilation.copy().pop(): # Compilation is only enabled for TRT + try: + TRT_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.TRT + ) + except ValueError as trt_compare_error: + raise ValueError( + ( + f"Optimizations that use Compilation and {trt_compare_error} " + "are not supported for GPU instances." + ) + ) + else: + try: + ( + VLLM_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.VLLM + ) + ) + except ValueError as vllm_compare_error: + try: # try both VLLM and TRT to cover both rule sets + ( + TRT_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.TRT + ) + ) + except ValueError as trt_compare_error: + if ( + str(trt_compare_error) + == "Quantization:smoothquant must be provided with Compilation" + ): + raise ValueError( + f"Optimizations that use {trt_compare_error} for GPU instances." 
+ ) + if str(trt_compare_error) == str(vllm_compare_error): + raise ValueError( + ( + f"Optimizations that use {trt_compare_error} " + "are not supported for GPU instances." + ) + ) + joint_error_msg = f""" + Optimization cannot be performed for the following reasons: + - Optimizations that use {trt_compare_error} are not supported for GPU instances. + - Optimizations that use {vllm_compare_error} are not supported for GPU instances. + """ + raise ValueError(textwrap.dedent(joint_error_msg)) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 5ea3d5f8a1..705d9892fe 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -121,7 +121,7 @@ from sagemaker.deprecations import deprecated_class from sagemaker.enums import EndpointType from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig -from sagemaker.user_agent import prepend_user_agent +from sagemaker.user_agent import get_user_agent_extra_suffix from sagemaker.utils import ( name_from_image, secondary_training_status_changed, @@ -285,6 +285,7 @@ def _initialize( Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client. Sets the region_name. """ + self.boto_session = boto_session or boto3.DEFAULT_SESSION or boto3.Session() self._region_name = self.boto_session.region_name @@ -293,19 +294,30 @@ def _initialize( "Must setup local AWS configuration with a region supported by SageMaker." ) - self.sagemaker_client = sagemaker_client or self.boto_session.client("sagemaker") - prepend_user_agent(self.sagemaker_client) + # Make use of user_agent_extra field of the botocore_config object + # to append SageMaker Python SDK specific user_agent suffix + # to the current User-Agent header value from boto3 + # This config will also make sure that user_agent never fails to log the User-Agent string + # even if boto User-Agent header format is updated in the future + # Ref: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + botocore_config = botocore.config.Config(user_agent_extra=get_user_agent_extra_suffix()) + + # Create sagemaker_client with the botocore_config object + # This config is customized to append SageMaker Python SDK specific user_agent suffix + self.sagemaker_client = sagemaker_client or self.boto_session.client( + "sagemaker", config=botocore_config + ) if sagemaker_runtime_client is not None: self.sagemaker_runtime_client = sagemaker_runtime_client else: - config = botocore.config.Config(read_timeout=80) + config = botocore.config.Config( + read_timeout=80, user_agent_extra=get_user_agent_extra_suffix() + ) self.sagemaker_runtime_client = self.boto_session.client( "runtime.sagemaker", config=config ) - prepend_user_agent(self.sagemaker_runtime_client) - if sagemaker_featurestore_runtime_client: self.sagemaker_featurestore_runtime_client = sagemaker_featurestore_runtime_client else: @@ -316,8 +328,9 @@ def _initialize( if sagemaker_metrics_client: self.sagemaker_metrics_client = sagemaker_metrics_client else: - self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics") - prepend_user_agent(self.sagemaker_metrics_client) + self.sagemaker_metrics_client = self.boto_session.client( + "sagemaker-metrics", config=botocore_config + ) self.s3_client = self.boto_session.client("s3", region_name=self.boto_region_name) self.s3_resource = self.boto_session.resource("s3", region_name=self.boto_region_name) @@ -618,43 +631,78 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): bucket = 
s3.Bucket(name=bucket_name)
         if bucket.creation_date is None:
-            try:
-                # trying head bucket call
-                s3.meta.client.head_bucket(Bucket=bucket.name)
-            except ClientError as e:
-                # bucket does not exist or forbidden to access
-                error_code = e.response["Error"]["Code"]
-                message = e.response["Error"]["Message"]
+            self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
-                if error_code == "404" and message == "Not Found":
-                    # bucket does not exist, create one
-                    try:
-                        if region == "us-east-1":
-                            # 'us-east-1' cannot be specified because it is the default region:
-                            # https://github.com/boto/boto3/issues/125
-                            s3.create_bucket(Bucket=bucket_name)
-                        else:
-                            s3.create_bucket(
-                                Bucket=bucket_name,
-                                CreateBucketConfiguration={"LocationConstraint": region},
-                            )
+        elif self._default_bucket_set_by_sdk:
+            self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)
+            expected_bucket_owner_id = self.account_id()
+            self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)
+
+    def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
+        """Checks if the bucket belongs to the expected owner and raises a ClientError if it does not
-                        logger.info("Created S3 bucket: %s", bucket_name)
-                    except ClientError as e:
-                        error_code = e.response["Error"]["Code"]
-                        message = e.response["Error"]["Message"]
-
-                        if (
-                            error_code == "OperationAborted"
-                            and "conflicting conditional operation" in message
-                        ):
-                            # If this bucket is already being concurrently created,
-                            # we don't need to create it again.
-                            pass
-                        else:
-                            raise
+        Args:
+            bucket_name (str): Name of the S3 bucket
+            s3 (ServiceResource): S3 resource from the boto session
+            expected_bucket_owner_id (str): Owner ID string
+
+        """
+        try:
+            if self.default_bucket_prefix:
+                s3.meta.client.list_objects_v2(
+                    Bucket=bucket_name,
+                    Prefix=self.default_bucket_prefix,
+                    ExpectedBucketOwner=expected_bucket_owner_id,
+                )
+            else:
+                s3.meta.client.head_bucket(
+                    Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
+                )
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+            message = e.response["Error"]["Message"]
+            if error_code == "403" and message == "Forbidden":
+                LOGGER.error(
+                    "Since default_bucket param was not set, SageMaker Python SDK tried to use "
+                    "%s bucket. "
+                    "This bucket cannot be used as it is not owned by Account %s. "
+                    "To unblock, it's recommended to use a custom default_bucket "
+                    "parameter in sagemaker.Session",
+                    bucket_name,
+                    expected_bucket_owner_id,
+                )
+            raise
+
+    def general_bucket_check_if_user_has_permission(
+        self, bucket_name, s3, bucket, region, bucket_creation_date_none
+    ):
+        """Checks whether the caller has permission to access the bucket
+
+        Any other error raised by the head-bucket call is re-raised here.
+        If the bucket does not exist, it creates one.
+
+        Args:
+            bucket_name (str): Name of the S3 bucket
+            s3 (ServiceResource): S3 resource from the boto session
+            bucket (Bucket): S3 bucket resource from the boto session
+            region (str): The region in which to create the bucket.
+            bucket_creation_date_none (bool): Indicates whether the S3 bucket already exists or not
+        """
+        try:
+            if self.default_bucket_prefix:
+                s3.meta.client.list_objects_v2(
+                    Bucket=bucket_name, Prefix=self.default_bucket_prefix
+                )
+            else:
+                s3.meta.client.head_bucket(Bucket=bucket_name)
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+            message = e.response["Error"]["Message"]
+            # bucket does not exist or forbidden to access
+            if bucket_creation_date_none:
+                if error_code == "404" and message == "Not Found":
+                    self.create_bucket_for_not_exist_error(bucket_name, region, s3)
                 elif error_code == "403" and message == "Forbidden":
-                    logger.error(
+                    LOGGER.error(
                         "Bucket %s exists, but access is forbidden. Please try again after "
                         "adding appropriate access.",
                         bucket.name,
@@ -663,27 +711,37 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
                 else:
                     raise
-        if self._default_bucket_set_by_sdk:
-            # make sure the s3 bucket is configured in users account.
-            expected_bucket_owner_id = self.account_id()
-            try:
-                s3.meta.client.head_bucket(
-                    Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
+    def create_bucket_for_not_exist_error(self, bucket_name, region, s3):
+        """Creates the S3 bucket in the given region
+
+        Args:
+            bucket_name (str): Name of the S3 bucket
+            s3 (ServiceResource): S3 resource from the boto session
+            region (str): The region in which to create the bucket.
+        """
+        # bucket does not exist, create one
+        try:
+            if region == "us-east-1":
+                # 'us-east-1' cannot be specified because it is the default region:
+                # https://github.com/boto/boto3/issues/125
+                s3.create_bucket(Bucket=bucket_name)
+            else:
+                s3.create_bucket(
+                    Bucket=bucket_name,
+                    CreateBucketConfiguration={"LocationConstraint": region},
                 )
-            except ClientError as e:
-                error_code = e.response["Error"]["Code"]
-                message = e.response["Error"]["Message"]
-                if error_code == "403" and message == "Forbidden":
-                    LOGGER.error(
-                        "Since default_bucket param was not set, SageMaker Python SDK tried to use "
-                        "%s bucket. "
-                        "This bucket cannot be configured to use as it is not owned by Account %s. "
-                        "To unblock it's recommended to use custom default_bucket "
-                        "parameter in sagemaker.Session",
-                        bucket_name,
-                        expected_bucket_owner_id,
-                    )
-                    raise
+
+            logger.info("Created S3 bucket: %s", bucket_name)
+        except ClientError as e:
+            error_code = e.response["Error"]["Code"]
+            message = e.response["Error"]["Message"]
+
+            if error_code == "OperationAborted" and "conflicting conditional operation" in message:
+                # If this bucket is already being concurrently created,
+                # we don't need to create it again.
+                pass
+            else:
+                raise
 def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tags: str):
         """Appends tags specified in the sagemaker_config to the given list of tags.
@@ -724,7 +782,7 @@ def _append_sagemaker_config_tags(self, tags: List[TagsDict], config_path_to_tag
         return all_tags
-    def train(  # noqa: C901
+    def get_train_request(
         self,
         input_mode,
         input_config,
@@ -759,7 +817,7 @@
         retry_strategy=None,
         remote_debug_config=None,
         session_chaining_config=None,
-    ):
+    ) -> Dict:
         """Create an Amazon SageMaker training job.
         Args:
@@ -902,7 +960,7 @@
                     "EnableInfraCheck": True,
                 }
         Returns:
-            str: ARN of the training job, if it is created.
+            Dict: a Dict containing the CreateTrainingJob request.
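+
+        For example (an illustrative sketch; the role, bucket, and image URI are
+        placeholders),
+
+        .. code:: python
+
+            request = sagemaker_session.get_train_request(
+                input_mode="File",
+                input_config=input_config,
+                role="<role_arn>",
+                job_name="my-training-job",
+                output_config={"S3OutputPath": "s3://<bucket>/output/"},
+                resource_config={
+                    "InstanceCount": 1,
+                    "InstanceType": "ml.c4.xlarge",
+                    "VolumeSizeInGB": 30,
+                },
+                stop_condition={"MaxRuntimeInSeconds": 3600},
+                image_uri="<training_image_uri>",
+            )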
""" tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags( @@ -984,11 +1042,243 @@ def train( # noqa: C901 environment=environment, retry_strategy=retry_strategy, ) + return train_request + + def train( # noqa: C901 + self, + input_mode, + input_config, + role=None, + job_name=None, + output_config=None, + resource_config=None, + vpc_config=None, + hyperparameters=None, + stop_condition=None, + tags=None, + metric_definitions=None, + enable_network_isolation=None, + image_uri=None, + training_image_config=None, + infra_check_config=None, + container_entry_point=None, + container_arguments=None, + algorithm_arn=None, + encrypt_inter_container_traffic=None, + use_spot_instances=False, + checkpoint_s3_uri=None, + checkpoint_local_path=None, + experiment_config=None, + debugger_rule_configs=None, + debugger_hook_config=None, + tensorboard_output_config=None, + enable_sagemaker_metrics=None, + profiler_rule_configs=None, + profiler_config=None, + environment: Optional[Dict[str, str]] = None, + retry_strategy=None, + remote_debug_config=None, + session_chaining_config=None, + ): + """Create an Amazon SageMaker training job. + + Args: + input_mode (str): The input mode that the algorithm supports. Valid modes: + * 'File' - Amazon SageMaker copies the training dataset from the S3 location to + a directory in the Docker container. + * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a + Unix-named pipe. + * 'FastFile' - Amazon SageMaker streams data from S3 on demand instead of + downloading the entire dataset before training begins. + input_config (list): A list of Channel objects. Each channel is a named input source. + Please refer to the format details described: + https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job + role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training + jobs and APIs that create Amazon SageMaker endpoints use this role to access + training data and model artifacts. You must grant sufficient permissions to this + role. + job_name (str): Name of the training job being created. + output_config (dict): The S3 URI where you want to store the training results and + optional KMS key ID. + resource_config (dict): Contains values for ResourceConfig: + * instance_count (int): Number of EC2 instances to use for training. + The key in resource_config is 'InstanceCount'. + * instance_type (str): Type of EC2 instance to use for training, for example, + 'ml.c4.xlarge'. The key in resource_config is 'InstanceType'. + vpc_config (dict): Contains values for VpcConfig: + * subnets (list[str]): List of subnet ids. + The key in vpc_config is 'Subnets'. + * security_group_ids (list[str]): List of security group ids. + The key in vpc_config is 'SecurityGroupIds'. + hyperparameters (dict): Hyperparameters for model training. The hyperparameters are + made accessible as a dict[str, str] to the training code on SageMaker. For + convenience, this accepts other types for keys and values, but ``str()`` will be + called to convert them before training. + stop_condition (dict): Defines when training shall finish. Contains entries that can + be understood by the service like ``MaxRuntimeInSeconds``. + tags (Optional[Tags]): Tags for labeling a training job. For more, see + https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) + used to evaluate the training jobs. 
Each dictionary contains two keys: 'Name' for + the name of the metric, and 'Regex' for the regular expression used to extract the + metric from the logs. + enable_network_isolation (bool): Whether to request for the training job to run with + network isolation or not. + image_uri (str): Docker image containing training code. + training_image_config(dict): Training image configuration. + Optionally, the dict can contain 'TrainingRepositoryAccessMode' and + 'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig'). + For example, + + .. code:: python + + training_image_config = { + "TrainingRepositoryAccessMode": "Vpc", + "TrainingRepositoryAuthConfig": { + "TrainingRepositoryCredentialsProviderArn": + "arn:aws:lambda:us-west-2:1234567890:function:test" + }, + } + + If TrainingRepositoryAccessMode is set to Vpc, the training image is accessed + through a private Docker registry in customer Vpc. If it's set to Platform or None, + the training image is accessed through ECR. + If TrainingRepositoryCredentialsProviderArn is provided, the credentials to + authenticate to the private Docker registry will be retrieved from this AWS Lambda + function. (default: ``None``). When it's set to None, SageMaker will not do + authentication before pulling the image in the private Docker registry. + container_entry_point (List[str]): Optional. The entrypoint script for a Docker + container used to run a training job. This script takes precedence over + the default train processing instructions. + container_arguments (List[str]): Optional. The arguments for a container used to run + a training job. + algorithm_arn (str): Algorithm Arn from Marketplace. + encrypt_inter_container_traffic (bool): Specifies whether traffic between training + containers is encrypted for the training job (default: ``False``). + use_spot_instances (bool): whether to use spot instances for training. + checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints + that the algorithm persists (if any) during training. (default: + ``None``). + checkpoint_local_path (str): The local path that the algorithm + writes its checkpoints to. SageMaker will persist all files + under this path to `checkpoint_s3_uri` continually during + training. On job startup the reverse happens - data from the + s3 location is downloaded to this path before the algorithm is + started. If the path is unset then SageMaker assumes the + checkpoints will be provided under `/opt/ml/checkpoints/`. + (default: ``None``). + experiment_config (dict[str, str]): Experiment management configuration. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. + The behavior of setting these keys is as follows: + * If `ExperimentName` is supplied but `TrialName` is not a Trial will be + automatically created and the job's Trial Component associated with the Trial. + * If `TrialName` is supplied and the Trial already exists the job's Trial Component + will be associated with the Trial. + * If both `ExperimentName` and `TrialName` are not supplied the trial component + will be unassociated. + * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. + enable_sagemaker_metrics (bool): enable SageMaker Metrics Time + Series. For more information see: + https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html + #SageMaker-Type + -AlgorithmSpecification-EnableSageMakerMetricsTimeSeries + (default: ``None``). 
+            profiler_rule_configs (list[dict]): A list of profiler rule
+                configurations.
+            profiler_config (dict): Configuration for how profiling information is emitted
+                with SageMaker Profiler. (default: ``None``).
+            remote_debug_config(dict): Configuration for RemoteDebug. (default: ``None``)
+                The dict can contain 'EnableRemoteDebug'(bool).
+                For example,
+
+                .. code:: python
+
+                    remote_debug_config = {
+                        "EnableRemoteDebug": True,
+                    }
+            session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``)
+                The dict can contain 'EnableSessionTagChaining'(bool).
+                For example,
+
+                .. code:: python
+
+                    session_chaining_config = {
+                        "EnableSessionTagChaining": True,
+                    }
+            environment (dict[str, str]): Environment variables to be set for
+                use during training job (default: ``None``)
+            retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
+                * max_retry_attempts (int): Number of times a job should be retried.
+                The key in RetryStrategy is 'MaxRetryAttempts'.
+            infra_check_config(dict): Infra check configuration.
+                Optionally, the dict can contain 'EnableInfraCheck'(bool).
+                For example,
+
+                .. code:: python
+
+                    infra_check_config = {
+                        "EnableInfraCheck": True,
+                    }
+        Returns:
+            str: ARN of the training job, if it is created.
+
+        Raises:
+            - botocore.exceptions.ClientError: If SageMaker throws an exception while creating
+              the training job.
+            - ValueError: If both image_uri and algorithm are provided, or if neither is provided.
+        """
+        train_request = self.get_train_request(
+            input_mode,
+            input_config,
+            role,
+            job_name,
+            output_config,
+            resource_config,
+            vpc_config,
+            hyperparameters,
+            stop_condition,
+            tags,
+            metric_definitions,
+            enable_network_isolation,
+            image_uri,
+            training_image_config,
+            infra_check_config,
+            container_entry_point,
+            container_arguments,
+            algorithm_arn,
+            encrypt_inter_container_traffic,
+            use_spot_instances,
+            checkpoint_s3_uri,
+            checkpoint_local_path,
+            experiment_config,
+            debugger_rule_configs,
+            debugger_hook_config,
+            tensorboard_output_config,
+            enable_sagemaker_metrics,
+            profiler_rule_configs,
+            profiler_config,
+            environment,
+            retry_strategy,
+            remote_debug_config,
+            session_chaining_config,
+        )

        def submit(request):
-            logger.info("Creating training-job with name: %s", job_name)
-            logger.debug("train request: %s", json.dumps(request, indent=4))
-            self.sagemaker_client.create_training_job(**request)
+            try:
+                logger.info("Creating training-job with name: %s", job_name)
+                logger.debug("train request: %s", json.dumps(request, indent=4))
+                self.sagemaker_client.create_training_job(**request)
+            except Exception as e:
+                troubleshooting = (
+                    "https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-python-sdk-troubleshooting.html"
+                    "#sagemaker-python-sdk-troubleshooting-create-training-job"
+                )
+                logger.error(
+                    "Please check the troubleshooting guide for common errors: %s", troubleshooting
+                )
+                raise e

        self._intercept_create_request(train_request, submit, self.train.__name__)
@@ -1295,6 +1585,15 @@ def update_training_job(
                remote_debug_config = {
                    "EnableRemoteDebug": True,
                }
+
+        Returns:
+            str: ARN of the training job.
+
+        Raises:
+            - botocore.exceptions.ClientError: If SageMaker throws an error while updating the
+              training job.
+            - botocore.exceptions.ParamValidationError: If any request parameters are in an invalid
+              format.
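+
+        Example (a minimal usage sketch; the job name and config value below are
+        hypothetical):
+
+            .. code:: python
+
+                sagemaker_session.update_training_job(
+                    job_name="my-training-job",
+                    remote_debug_config={"EnableRemoteDebug": False},
+                )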
""" # No injections from sagemaker_config because the UpdateTrainingJob API's resource_config # object accepts fewer parameters than the CreateTrainingJob API, and none that the @@ -1309,9 +1608,28 @@ def update_training_job( resource_config=resource_config, remote_debug_config=remote_debug_config, ) - logger.info("Updating training job with name %s", job_name) - logger.debug("Update request: %s", json.dumps(update_training_job_request, indent=4)) - self.sagemaker_client.update_training_job(**update_training_job_request) + try: + logger.info("Updating training job with name %s", job_name) + logger.debug("Update request: %s", json.dumps(update_training_job_request, indent=4)) + self.sagemaker_client.update_training_job(**update_training_job_request) + except botocore.exceptions.ParamValidationError as e: + troubleshooting = ( + "Incorrect request parameter was provided. Check the API documentation: " + "https://docs.aws.amazon.com/sagemaker/latest/APIReference/" + "API_UpdateTrainingJob.html#API_UpdateTrainingJob_RequestParameters" + ) + logger.error("%s", troubleshooting) + raise e + except botocore.exceptions.ClientError as e: + troubleshooting = ( + "https://docs.aws.amazon.com/sagemaker/latest/dg/" + "sagemaker-python-sdk-troubleshooting.html" + "#sagemaker-python-sdk-troubleshooting-update-training-job" + ) + logger.error( + "Please check the troubleshooting guide for common errors: %s", troubleshooting + ) + raise e def _get_update_training_job_request( self, @@ -1414,6 +1732,10 @@ def process( * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + + Raises: + - botocore.exceptions.ClientError: If Sagemaker throws an error while creating + processing job. """ tags = _append_project_tags(format_tags(tags)) tags = self._append_sagemaker_config_tags( @@ -1477,9 +1799,20 @@ def process( ) def submit(request): - logger.info("Creating processing-job with name %s", job_name) - logger.debug("process request: %s", json.dumps(request, indent=4)) - self.sagemaker_client.create_processing_job(**request) + try: + logger.info("Creating processing-job with name %s", job_name) + logger.debug("process request: %s", json.dumps(request, indent=4)) + self.sagemaker_client.create_processing_job(**request) + except Exception as e: + troubleshooting = ( + "https://docs.aws.amazon.com/sagemaker/latest/dg/" + "sagemaker-python-sdk-troubleshooting.html" + "#sagemaker-python-sdk-troubleshooting-create-processing-job" + ) + logger.error( + "Please check the troubleshooting guide for common errors: %s", troubleshooting + ) + raise e self._intercept_create_request(process_request, submit, self.process.__name__) @@ -2365,6 +2698,75 @@ def describe_training_job(self, job_name): """ return self.sagemaker_client.describe_training_job(TrainingJobName=job_name) + def describe_training_plan(self, training_plan_name): + """Calls the DescribeTrainingPlan API for the given training plan and returns the response. + + Args: + training_plan_name (str): The name of the training plan to describe. + + Returns: + dict: A dictionary response with the training plan description. 
+ """ + return self.sagemaker_client.describe_training_plan(TrainingPlanName=training_plan_name) + + def list_training_plans( + self, + filters=None, + requested_start_time_after=None, + requested_start_time_before=None, + start_time_after=None, + start_time_before=None, + sort_order=None, + sort_by=None, + max_results=None, + next_token=None, + ): + """Calls the ListrTrainingPlan API for the given filters and returns the response. + + Args: + filters (dict): A dictionary of key-value pairs used to filter the training plans. + Default to None. + requested_start_time_after (datetime): A timestamp that filters the results + to only include training plans with a requested start time after this timestamp. + requested_start_time_before (datetime): A timestamp that filters the results + to only include training plans with a requested start time before this timestamp. + start_time_after (datetime): A timestamp that filters the results + to only include training plans with an actual start time after this timestamp. + start_time_before (datetime): A timestamp that filters the results + to only include training plans with an actual start time before this timestamp. + sort_order (str): The order that the training plans will be listed in result. + Default to None. + sort_by (str): The value that the training plans will be sorted by. + Default to None. + max_results (int): The number of candidates will be listed in results, + between 1 and 100. Default to None. If None, will return all the training_plans. + next_token (str): The pagination token. Default to None. + + Returns: + dict: A dictionary containing the following keys: + - "TrainingPlanSummaries": A list of dictionaries, where each dictionary represents + a training plan. + - "NextToken": A token to retrieve the next set of results, if there are more + than the maximum number of results returned. + """ + list_training_plan_args = {} + + def check_object(key, value): + if value is not None: + list_training_plan_args[key] = value + + check_object("Filters", filters) + check_object("SortBy", sort_by) + check_object("SortOrder", sort_order) + check_object("RequestedStartTimeAfter", requested_start_time_after) + check_object("RequestedStartTimeBefore", requested_start_time_before) + check_object("StartTimeAfter", start_time_after) + check_object("StartTimeBefore", start_time_before) + check_object("NextToken", next_token) + check_object("MaxResults", max_results) + + return self.sagemaker_client.list_training_plans(**list_training_plan_args) + def auto_ml( self, input_config, @@ -2578,6 +2980,24 @@ def wait_for_auto_ml_job(self, job, poll=5): _check_job_status(job, desc, "AutoMLJobStatus") return desc + def wait_for_optimization_job(self, job, poll=5): + """Wait for an Amazon SageMaker Optimization job to complete. + + Args: + job (str): Name of optimization job to wait for. + poll (int): Polling interval in seconds (default: 5). + + Returns: + (dict): Return value from the ``DescribeOptimizationJob`` API. + + Raises: + exceptions.ResourceNotFound: If optimization job fails with CapacityError. + exceptions.UnexpectedStatusException: If optimization job fails. 
+ """ + desc = _wait_until(lambda: _optimization_job_status(self.sagemaker_client, job), poll) + _check_job_status(job, desc, "OptimizationJobStatus") + return desc + def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this method self, job_name, wait=False, poll=10 ): @@ -3128,6 +3548,9 @@ def tune( # noqa: C901 tune_request["Autotune"] = {"Mode": "Enabled"} tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, TRAINING_JOB, TAGS) + ) if tags is not None: tune_request["Tags"] = tags @@ -3239,6 +3662,9 @@ def _get_tuning_request( tune_request["WarmStartConfig"] = warm_start_config tags = _append_project_tags(format_tags(tags)) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, TRAINING_JOB, TAGS) + ) if tags is not None: tune_request["Tags"] = tags @@ -4042,6 +4468,8 @@ def create_model_package_from_containers( task=None, skip_model_validation="None", source_uri=None, + model_card=None, + model_life_cycle=None, ): """Get request dictionary for CreateModelPackage API. @@ -4079,6 +4507,9 @@ def create_model_package_from_containers( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). + model_card (ModeCard or ModelPackageModelCard): document contains qualitative and + quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). """ if containers: # Containers are provided. Now we can merge missing entries from config. @@ -4136,17 +4567,67 @@ def create_model_package_from_containers( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + model_life_cycle=model_life_cycle, ) def submit(request): if model_package_group_name is not None and not model_package_group_name.startswith( "arn:" ): - _create_resource( - lambda: self.sagemaker_client.create_model_package_group( - ModelPackageGroupName=request["ModelPackageGroupName"] + is_model_package_group_present = False + try: + model_package_groups_response = self.search( + resource="ModelPackageGroup", + search_expression={ + "Filters": [ + { + "Name": "ModelPackageGroupName", + "Value": request["ModelPackageGroupName"], + "Operator": "Equals", + } + ], + }, + ) + if len(model_package_groups_response.get("Results")) > 0: + is_model_package_group_present = True + except Exception: # pylint: disable=W0703 + model_package_groups = [] + model_package_groups_response = self.sagemaker_client.list_model_package_groups( + NameContains=request["ModelPackageGroupName"], + ) + model_package_groups = ( + model_package_groups + + model_package_groups_response["ModelPackageGroupSummaryList"] + ) + next_token = model_package_groups_response.get("NextToken") + + while next_token is not None and next_token != "": + model_package_groups_response = ( + self.sagemaker_client.list_model_package_groups( + NameContains=request["ModelPackageGroupName"], NextToken=next_token + ) + ) + model_package_groups = ( + model_package_groups + + model_package_groups_response["ModelPackageGroupSummaryList"] + ) + next_token = model_package_groups_response.get("NextToken") + + filtered_model_package_group = list( + filter( + lambda mpg: mpg.get("ModelPackageGroupName") + == request["ModelPackageGroupName"], + model_package_groups, + ) + ) + is_model_package_group_present = len(filtered_model_package_group) 
> 0
+                if not is_model_package_group_present:
+                    _create_resource(
+                        lambda: self.sagemaker_client.create_model_package_group(
+                            ModelPackageGroupName=request["ModelPackageGroupName"]
+                        )
                    )
-                )
            if "SourceUri" in request and request["SourceUri"] is not None:
                # Remove inference spec from request if the
                # given source uri can lead to auto-population of it
@@ -4210,6 +4691,49 @@ def wait_for_model_package(self, model_package_name, poll=5):
            )
        return desc

+    def get_most_recently_created_approved_model_package(self, model_package_group_name):
+        """Returns the most recently created and approved model package in a model package group.
+
+        Args:
+            model_package_group_name (str): Name or ARN of the model package group.
+
+        Returns:
+            sagemaker.model.ModelPackage: The most recently created approved model package,
+            or ``None`` if the group contains no approved model packages.
+        """
+
+        approved_model_packages = self.sagemaker_client.list_model_packages(
+            ModelPackageGroupName=model_package_group_name,
+            ModelApprovalStatus="Approved",
+            SortBy="CreationTime",
+            SortOrder="Descending",
+            MaxResults=1,
+        )
+        next_token = approved_model_packages.get("NextToken")
+
+        while (
+            len(approved_model_packages.get("ModelPackageSummaryList")) == 0
+            and next_token is not None
+            and next_token != ""
+        ):
+            approved_model_packages = self.sagemaker_client.list_model_packages(
+                ModelPackageGroupName=model_package_group_name,
+                ModelApprovalStatus="Approved",
+                SortBy="CreationTime",
+                SortOrder="Descending",
+                MaxResults=1,
+                NextToken=next_token,
+            )
+            next_token = approved_model_packages.get("NextToken")
+
+        if len(approved_model_packages.get("ModelPackageSummaryList")) == 0:
+            return None
+
+        return sagemaker.model.ModelPackage(
+            model_package_arn=approved_model_packages.get("ModelPackageSummaryList")[0].get(
+                "ModelPackageArn"
+            )
+        )
+
    def describe_model(self, name):
        """Calls the DescribeModel API for the given model name.
@@ -4235,6 +4759,10 @@ def create_endpoint_config(
        model_data_download_timeout=None,
        container_startup_health_check_timeout=None,
        explainer_config_dict=None,
+        async_inference_config_dict=None,
+        serverless_inference_config_dict=None,
+        routing_config: Optional[Dict[str, Any]] = None,
+        inference_ami_version: Optional[str] = None,
    ):
        """Create an Amazon SageMaker endpoint configuration.
@@ -4272,6 +4800,30 @@
                -inference-algo-ping-requests
            explainer_config_dict (dict): Specifies configuration to enable explainers.
                Default: None.
+            async_inference_config_dict (dict): Specifies
+                configuration related to async endpoint. Use this configuration when trying
+                to create an async endpoint and make async inference. If an empty config object
+                is passed through, the default config will be used to deploy an async endpoint.
+                A real-time endpoint is deployed if it's None. (default: None).
+            serverless_inference_config_dict (dict):
+                Specifies configuration related to serverless endpoint. Use this configuration
+                when trying to create a serverless endpoint and make serverless inference. If
+                an empty object is passed through, pre-defined values in the
+                ``ServerlessInferenceConfig`` class will be used to deploy a serverless endpoint.
+                An instance-based endpoint is deployed if it's None. (default: None).
+            routing_config (Optional[Dict[str, Any]]): Settings that control how the endpoint
+                routes incoming traffic to the instances that the endpoint hosts.
+                Currently, the supported dictionary key is ``RoutingStrategy``.
+
+                ..
code:: python
+
+                    {
+                        "RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM
+                    }
+            inference_ami_version (Optional[str]):
+                Specifies an option from a collection of preconfigured
+                Amazon Machine Image (AMI) images. For a full list of options, see:
+                https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html

        Example:
            >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
@@ -4291,9 +4843,12 @@
            instance_type,
            initial_instance_count,
            accelerator_type=accelerator_type,
+            serverless_inference_config=serverless_inference_config_dict,
            volume_size=volume_size,
            model_data_download_timeout=model_data_download_timeout,
            container_startup_health_check_timeout=container_startup_health_check_timeout,
+            routing_config=routing_config,
+            inference_ami_version=inference_ami_version,
        )
        production_variants = [provided_production_variant]
        # Currently we just inject CoreDumpConfig.KmsKeyId from the config for production variant.
@@ -4333,6 +4888,14 @@
            )
            request["DataCaptureConfig"] = inferred_data_capture_config_dict

+        if async_inference_config_dict is not None:
+            inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config(
+                async_inference_config_dict,
+                ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
+                sagemaker_session=self,
+            )
+            request["AsyncInferenceConfig"] = inferred_async_inference_config_dict
+
        if explainer_config_dict is not None:
            request["ExplainerConfig"] = explainer_config_dict
@@ -4498,6 +5061,10 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True, live

        Returns:
            str: Name of the Amazon SageMaker ``Endpoint`` created.
+
+        Raises:
+            botocore.exceptions.ClientError: If SageMaker throws an exception while creating
+                the endpoint.
        """
        logger.info("Creating endpoint with name %s", endpoint_name)
@@ -4506,16 +5073,26 @@
        tags = self._append_sagemaker_config_tags(
            tags, "{}.{}.{}".format(SAGEMAKER, ENDPOINT, TAGS)
        )
-
-        res = self.sagemaker_client.create_endpoint(
-            EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags
-        )
-        if res:
-            self.endpoint_arn = res["EndpointArn"]
-
-        if wait:
-            self.wait_for_endpoint(endpoint_name, live_logging=live_logging)
-        return endpoint_name
+        try:
+            res = self.sagemaker_client.create_endpoint(
+                EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags
+            )
+            if res:
+                self.endpoint_arn = res["EndpointArn"]
+
+            if wait:
+                self.wait_for_endpoint(endpoint_name, live_logging=live_logging)
+            return endpoint_name
+        except Exception as e:
+            troubleshooting = (
+                "https://docs.aws.amazon.com/sagemaker/latest/dg/"
+                "sagemaker-python-sdk-troubleshooting.html"
+                "#sagemaker-python-sdk-troubleshooting-create-endpoint"
+            )
+            logger.error(
+                "Please check the troubleshooting guide for common errors: %s", troubleshooting
+            )
+            raise e

    def endpoint_in_service_or_not(self, endpoint_name: str):
        """Check whether an Amazon SageMaker ``Endpoint``` is in IN_SERVICE status.
@@ -4560,7 +5137,9 @@ def update_endpoint(self, endpoint_name, endpoint_config_name, wait=True):
            str: Name of the Amazon SageMaker ``Endpoint`` being updated.
Raises:
-            ValueError: if the endpoint does not already exist
+            - ValueError: If the endpoint does not already exist
+            - botocore.exceptions.ClientError: If SageMaker throws an error while
+              creating the endpoint config, describing the endpoint, or updating the endpoint
        """
        if not _deployment_entity_exists(
            lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
@@ -4570,15 +5149,27 @@
                "existing endpoint name".format(endpoint_name)
            )

-        res = self.sagemaker_client.update_endpoint(
-            EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
-        )
-        if res:
-            self.endpoint_arn = res["EndpointArn"]
+        try:

-        if wait:
-            self.wait_for_endpoint(endpoint_name)
-        return endpoint_name
+            res = self.sagemaker_client.update_endpoint(
+                EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
+            )
+            if res:
+                self.endpoint_arn = res["EndpointArn"]
+
+            if wait:
+                self.wait_for_endpoint(endpoint_name)
+            return endpoint_name
+        except Exception as e:
+            troubleshooting = (
+                "https://docs.aws.amazon.com/sagemaker/latest/dg/"
+                "sagemaker-python-sdk-troubleshooting.html"
+                "#sagemaker-python-sdk-troubleshooting-update-endpoint"
+            )
+            logger.error(
+                "Please check the troubleshooting guide for common errors: %s", troubleshooting
+            )
+            raise e

    def is_inference_component_based_endpoint(self, endpoint_name):
        """Returns 'True' if endpoint is inference-component-based, 'False' otherwise.
@@ -4684,7 +5275,7 @@ def create_inference_component(
        tags = self._append_sagemaker_config_tags(
            tags, "{}.{}.{}".format(SAGEMAKER, INFERENCE_COMPONENT, TAGS)
        )
-        if len(tags) != 0:
+        if tags and len(tags) != 0:
            request["Tags"] = tags
        self.sagemaker_client.create_inference_component(**request)
@@ -4859,7 +5450,7 @@ def update_inference_component(
        return inference_component_name

    def delete_inference_component(self, inference_component_name: str, wait: bool = False):
-        """Deletes a InferenceComponent.
+        """Deletes an InferenceComponent.

        Args:
            inference_component_name (str): Name of the Amazon SageMaker ``InferenceComponent``
@@ -5053,7 +5644,7 @@ def get_tagging_resources(self, tag_filters, resource_type_filters):
            resource_tag_response = self.resource_group_tagging_client.get_resources(
                TagFilters=tag_filters,
                ResourceTypeFilters=resource_type_filters,
-                NextToken=next_token,
+                PaginationToken=next_token,
            )
            resource_list = resource_list + resource_tag_response["ResourceTagMappingList"]
            next_token = resource_tag_response.get("PaginationToken")
@@ -6692,6 +7283,323 @@ def wait_for_inference_recommendations_job(
        _check_job_status(job_name, desc, "Status")
        return desc

+    def create_presigned_mlflow_tracking_server_url(
+        self,
+        tracking_server_name: str,
+        expires_in_seconds: int = None,
+        session_expiration_duration_in_seconds: int = None,
+    ) -> Dict[str, Any]:
+        """Creates a presigned URL to access the MLflow UI.
+
+        Args:
+            tracking_server_name (str): Name of the MLflow Tracking Server.
+            expires_in_seconds (int): Expiration duration of the URL.
+            session_expiration_duration_in_seconds (int): Session duration of the URL.
+        Returns:
+            (dict): Return value from the ``CreatePresignedMlflowTrackingServerUrl`` API.
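+
+        Example (a minimal usage sketch; the tracking server name below is hypothetical):
+
+            .. code:: python
+
+                response = sagemaker_session.create_presigned_mlflow_tracking_server_url(
+                    tracking_server_name="my-tracking-server",
+                    expires_in_seconds=300,
+                )
+                # assumes the response carries an AuthorizedUrl field
+                print(response["AuthorizedUrl"])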
+
+        """
+
+        create_presigned_url_args = {"TrackingServerName": tracking_server_name}
+        if expires_in_seconds is not None:
+            create_presigned_url_args["ExpiresInSeconds"] = expires_in_seconds
+
+        if session_expiration_duration_in_seconds is not None:
+            create_presigned_url_args["SessionExpirationDurationInSeconds"] = (
+                session_expiration_duration_in_seconds
+            )
+
+        return self.sagemaker_client.create_presigned_mlflow_tracking_server_url(
+            **create_presigned_url_args
+        )
+
+    def create_hub(
+        self,
+        hub_name: str,
+        hub_description: str,
+        hub_display_name: str = None,
+        hub_search_keywords: List[str] = None,
+        s3_storage_config: Dict[str, Any] = None,
+        tags: List[Dict[str, Any]] = None,
+    ) -> Dict[str, str]:
+        """Creates a SageMaker Hub
+
+        Args:
+            hub_name (str): The name of the Hub to create.
+            hub_description (str): A description of the Hub.
+            hub_display_name (str): The display name of the Hub.
+            hub_search_keywords (list): The searchable keywords for the Hub.
+            s3_storage_config (S3StorageConfig): The Amazon S3 storage configuration for the Hub.
+            tags (list): Any tags to associate with the Hub.
+
+        Returns:
+            (dict): Return value from the ``CreateHub`` API.
+        """
+        request = {"HubName": hub_name, "HubDescription": hub_description}
+
+        if hub_display_name:
+            request["HubDisplayName"] = hub_display_name
+        else:
+            request["HubDisplayName"] = hub_name
+
+        if hub_search_keywords:
+            request["HubSearchKeywords"] = hub_search_keywords
+        if s3_storage_config:
+            request["S3StorageConfig"] = s3_storage_config
+        if tags:
+            request["Tags"] = tags
+
+        return self.sagemaker_client.create_hub(**request)
+
+    def describe_hub(self, hub_name: str) -> Dict[str, Any]:
+        """Describes a SageMaker Hub
+
+        Args:
+            hub_name (str): The name of the hub to describe.
+
+        Returns:
+            (dict): Return value for ``DescribeHub`` API
+        """
+        request = {"HubName": hub_name}
+
+        return self.sagemaker_client.describe_hub(**request)
+
+    def list_hubs(
+        self,
+        creation_time_after: str = None,
+        creation_time_before: str = None,
+        max_results: int = None,
+        max_schema_version: str = None,
+        name_contains: str = None,
+        next_token: str = None,
+        sort_by: str = None,
+        sort_order: str = None,
+    ) -> Dict[str, Any]:
+        """Lists all existing SageMaker Hubs
+
+        Args:
+            creation_time_after (str): Only list Hubs that were created after
+                the time specified.
+            creation_time_before (str): Only list Hubs that were created
+                before the time specified.
+            max_results (int): The maximum number of Hubs to list.
+            max_schema_version (str): The upper bound of the HubContentSchemaVersion.
+            name_contains (str): Only list Hubs whose name contains the specified string.
+            next_token (str): If the response to a previous ``ListHubs`` request was
+                truncated, the response includes a ``NextToken``. To retrieve the next set of
+                Hubs, use the token in the next request.
+            sort_by (str): Sort Hubs by either name or creation time.
+            sort_order (str): Sort Hubs by ascending or descending order.
+        Returns:
+            (dict): Return value for ``ListHubs`` API
+        """
+        request = {}
+        if creation_time_after:
+            request["CreationTimeAfter"] = creation_time_after
+        if creation_time_before:
+            request["CreationTimeBefore"] = creation_time_before
+        if max_results:
+            request["MaxResults"] = max_results
+        if max_schema_version:
+            request["MaxSchemaVersion"] = max_schema_version
+        if name_contains:
+            request["NameContains"] = name_contains
+        if next_token:
+            request["NextToken"] = next_token
+        if sort_by:
+            request["SortBy"] = sort_by
+        if sort_order:
+            request["SortOrder"] = sort_order
+
+        return self.sagemaker_client.list_hubs(**request)
+
+    def list_hub_contents(
+        self,
+        hub_name: str,
+        hub_content_type: str,
+        creation_time_after: str = None,
+        creation_time_before: str = None,
+        max_results: int = None,
+        max_schema_version: str = None,
+        name_contains: str = None,
+        next_token: str = None,
+        sort_by: str = None,
+        sort_order: str = None,
+    ) -> Dict[str, Any]:
+        """Lists the HubContents in a SageMaker Hub
+
+        Args:
+            hub_name (str): The name of the Hub to list the contents of.
+            hub_content_type (str): The type of the HubContent to list.
+            creation_time_after (str): Only list HubContent that was created after the
+                time specified.
+            creation_time_before (str): Only list HubContent that was created before the
+                time specified.
+            max_results (int): The maximum number of HubContent entries to list.
+            max_schema_version (str): The upper bound of the HubContentSchemaVersion.
+            name_contains (str): Only list HubContent if the name contains the specified string.
+            next_token (str): If the response to a previous ``ListHubContents`` request was
+                truncated, the response includes a ``NextToken``. To retrieve the next set of
+                hub content, use the token in the next request.
+            sort_by (str): Sort HubContent by either name or creation time.
+            sort_order (str): Sort HubContent by ascending or descending order.
+        Returns:
+            (dict): Return value for ``ListHubContents`` API
+        """
+        request = {"HubName": hub_name, "HubContentType": hub_content_type}
+        if creation_time_after:
+            request["CreationTimeAfter"] = creation_time_after
+        if creation_time_before:
+            request["CreationTimeBefore"] = creation_time_before
+        if max_results:
+            request["MaxResults"] = max_results
+        if max_schema_version:
+            request["MaxSchemaVersion"] = max_schema_version
+        if name_contains:
+            request["NameContains"] = name_contains
+        if next_token:
+            request["NextToken"] = next_token
+        if sort_by:
+            request["SortBy"] = sort_by
+        if sort_order:
+            request["SortOrder"] = sort_order
+
+        return self.sagemaker_client.list_hub_contents(**request)
+
+    def delete_hub(self, hub_name: str) -> None:
+        """Deletes a SageMaker Hub
+
+        Args:
+            hub_name (str): The name of the hub to delete.
+        """
+        request = {"HubName": hub_name}
+
+        return self.sagemaker_client.delete_hub(**request)
+
+    def create_hub_content_reference(
+        self,
+        hub_name: str,
+        source_hub_content_arn: str,
+        hub_content_name: str = None,
+        min_version: str = None,
+    ) -> Dict[str, str]:
+        """Creates a given HubContent reference in a SageMaker Hub
+
+        Args:
+            hub_name (str): The name of the Hub that you want to add content to.
+            source_hub_content_arn (str): The ARN of the hub content in the public/source Hub.
+            hub_content_name (str): The name of the reference that you want to add to the Hub.
+            min_version (str): A minimum version of the hub content to add to the Hub.
+
+        Returns:
+            (dict): Return value for ``CreateHubContentReference`` API
+        """
+
+        request = {"HubName": hub_name, "SageMakerPublicHubContentArn": source_hub_content_arn}
+
+        if hub_content_name:
+            request["HubContentName"] = hub_content_name
+        if min_version:
+            request["MinVersion"] = min_version
+
+        return self.sagemaker_client.create_hub_content_reference(**request)
+
+    def delete_hub_content_reference(
+        self, hub_name: str, hub_content_type: str, hub_content_name: str
+    ) -> None:
+        """Deletes a given HubContent reference in a SageMaker Hub
+
+        Args:
+            hub_name (str): The name of the Hub that you want to delete content in.
+            hub_content_type (str): The type of the content that you want to delete from a Hub.
+            hub_content_name (str): The name of the content that you want to delete from a Hub.
+        """
+        request = {
+            "HubName": hub_name,
+            "HubContentType": hub_content_type,
+            "HubContentName": hub_content_name,
+        }
+
+        return self.sagemaker_client.delete_hub_content_reference(**request)
+
+    def describe_hub_content(
+        self,
+        hub_content_name: str,
+        hub_content_type: str,
+        hub_name: str,
+        hub_content_version: str = None,
+    ) -> Dict[str, Any]:
+        """Describes a HubContent in a SageMaker Hub
+
+        Args:
+            hub_content_name (str): The name of the HubContent to describe.
+            hub_content_type (str): The type of HubContent in the Hub.
+            hub_name (str): The name of the Hub that contains the HubContent to describe.
+            hub_content_version (str): The version of the HubContent to describe.
+
+        Returns:
+            (dict): Return value for ``DescribeHubContent`` API
+        """
+        request = {
+            "HubContentName": hub_content_name,
+            "HubContentType": hub_content_type,
+            "HubName": hub_name,
+        }
+        if hub_content_version:
+            request["HubContentVersion"] = hub_content_version
+
+        return self.sagemaker_client.describe_hub_content(**request)
+
+    def list_hub_content_versions(
+        self,
+        hub_name,
+        hub_content_type: str,
+        hub_content_name: str,
+        min_version: str = None,
+        max_schema_version: str = None,
+        creation_time_after: str = None,
+        creation_time_before: str = None,
+        max_results: int = None,
+        next_token: str = None,
+        sort_by: str = None,
+        sort_order: str = None,
+    ) -> Dict[str, Any]:
+        """List all versions of a HubContent in a SageMaker Hub
+
+        Args:
+            hub_content_name (str): The name of the HubContent to list versions of.
+            hub_content_type (str): The type of HubContent in the Hub.
+            hub_name (str): The name of the Hub that contains the HubContent.
+
+        Returns:
+            (dict): Return value for ``ListHubContentVersions`` API
+        """
+
+        request = {
+            "HubName": hub_name,
+            "HubContentName": hub_content_name,
+            "HubContentType": hub_content_type,
+        }
+
+        if min_version:
+            request["MinVersion"] = min_version
+        if creation_time_after:
+            request["CreationTimeAfter"] = creation_time_after
+        if creation_time_before:
+            request["CreationTimeBefore"] = creation_time_before
+        if max_results:
+            request["MaxResults"] = max_results
+        if max_schema_version:
+            request["MaxSchemaVersion"] = max_schema_version
+        if next_token:
+            request["NextToken"] = next_token
+        if sort_by:
+            request["SortBy"] = sort_by
+        if sort_order:
+            request["SortOrder"] = sort_order
+
+        return self.sagemaker_client.list_hub_content_versions(**request)
+

 def get_model_package_args(
     content_types=None,
@@ -6717,6 +7625,8 @@
     task=None,
     skip_model_validation=None,
     source_uri=None,
+    model_card=None,
+    model_life_cycle=None,
 ):
     """Get arguments for create_model_package method.
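Taken together, the hub APIs added above support a reference workflow: create a reference to public hub content in a private hub, inspect it, and remove it again. A minimal sketch under assumed inputs (the hub name, content name, and ARN below are hypothetical, and ``ModelReference`` is used as the content type):

.. code:: python

    # hypothetical names/ARN, for illustration only
    session.create_hub_content_reference(
        hub_name="my-private-hub",
        source_hub_content_arn=(
            "arn:aws:sagemaker:us-west-2:aws:hub-content/SageMakerPublicHub/Model/example-model"
        ),
        hub_content_name="example-model",
    )
    details = session.describe_hub_content(
        hub_content_name="example-model",
        hub_content_type="ModelReference",
        hub_name="my-private-hub",
    )
    session.delete_hub_content_reference(
        hub_name="my-private-hub",
        hub_content_type="ModelReference",
        hub_content_name="example-model",
    )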
@@ -6756,6 +7666,9 @@
        skip_model_validation (str): Indicates if you want to skip model validation.
            Values can be "All" or "None" (default: None).
        source_uri (str): The URI of the source for the model package (default: None).
+        model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative
+            and quantitative information about a model (default: None).
+        model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

    Returns:
        dict: A dictionary of method argument names and values.
@@ -6812,6 +7725,16 @@
        model_package_args["skip_model_validation"] = skip_model_validation
    if source_uri is not None:
        model_package_args["source_uri"] = source_uri
+    if model_life_cycle is not None:
+        model_package_args["model_life_cycle"] = model_life_cycle._to_request_dict()
+    if model_card is not None:
+        original_req = model_card._create_request_args()
+        if original_req.get("ModelCardName") is not None:
+            del original_req["ModelCardName"]
+        if original_req.get("Content") is not None:
+            original_req["ModelCardContent"] = original_req["Content"]
+            del original_req["Content"]
+        model_package_args["model_card"] = original_req

    return model_package_args
@@ -6837,6 +7760,8 @@ def get_create_model_package_request(
    task=None,
    skip_model_validation="None",
    source_uri=None,
+    model_card=None,
+    model_life_cycle=None,
 ):
    """Get request dictionary for CreateModelPackage API.
@@ -6874,6 +7799,9 @@
        skip_model_validation (str): Indicates if you want to skip model validation.
            Values can be "All" or "None" (default: None).
        source_uri (str): The URI of the source for the model package (default: None).
+        model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative
+            and quantitative information about a model (default: None).
+        model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).
    """

    if all([model_package_name, model_package_group_name]):
@@ -6971,6 +7899,10 @@
    request_dict["CertifyForMarketplace"] = marketplace_cert
    request_dict["ModelApprovalStatus"] = approval_status
    request_dict["SkipModelValidation"] = skip_model_validation
+    if model_card is not None:
+        request_dict["ModelCard"] = model_card
+    if model_life_cycle is not None:
+        request_dict["ModelLifeCycle"] = model_life_cycle

    return request_dict
@@ -7124,6 +8056,8 @@ def container_def(
    container_mode=None,
    image_config=None,
    accept_eula=None,
+    additional_model_data_sources=None,
+    model_reference_arn=None,
 ):
    """Create a definition for executing a container as part of a SageMaker model.
@@ -7146,6 +8080,8 @@
            The `accept_eula` value must be explicitly defined as `True` in order to accept
            the end-user license agreement (EULA) that some
            models require. (Default: None).
+        additional_model_data_sources (PipelineVariable or dict): Additional location
+            of SageMaker model data (default: None).
Returns: dict[str, str]: A complete container definition object usable with the CreateModel API if @@ -7155,6 +8091,9 @@ def container_def( env = {} c_def = {"Image": image_uri, "Environment": env} + if additional_model_data_sources: + c_def["AdditionalModelDataSources"] = additional_model_data_sources + if isinstance(model_data_url, str) and ( not (model_data_url.startswith("s3://") and model_data_url.endswith("tar.gz")) or accept_eula is None @@ -7176,6 +8115,11 @@ def container_def( c_def["ModelDataSource"]["S3DataSource"]["ModelAccessConfig"] = { "AcceptEula": accept_eula } + if model_reference_arn: + c_def["ModelDataSource"]["S3DataSource"]["HubAccessConfig"] = { + "HubContentArn": model_reference_arn + } + elif model_data_url is not None: c_def["ModelDataUrl"] = model_data_url @@ -7218,6 +8162,7 @@ def production_variant( container_startup_health_check_timeout=None, managed_instance_scaling=None, routing_config=None, + inference_ami_version=None, ): """Create a production variant description suitable for use in a ``ProductionVariant`` list. @@ -7282,6 +8227,9 @@ def production_variant( RoutingConfig=routing_config, ) + if inference_ami_version: + production_variant_configuration["InferenceAmiVersion"] = inference_ami_version + return production_variant_configuration @@ -7590,6 +8538,31 @@ def _auto_ml_job_status(sagemaker_client, job_name): return desc +def _optimization_job_status(sagemaker_client, job_name): + """Placeholder docstring""" + optimization_job_status_codes = { + "INPROGRESS": ".", + "COMPLETED": "!", + "FAILED": "*", + "STARTING": ".", + "STOPPING": "_", + "STOPPED": "s", + } + in_progress_statuses = ["INPROGRESS", "STARTING", "STOPPING"] + + desc = sagemaker_client.describe_optimization_job(OptimizationJobName=job_name) + status = desc["OptimizationJobStatus"] + + print(optimization_job_status_codes.get(status, "?"), end="") + sys.stdout.flush() + + if status in in_progress_statuses: + return None + + print("") + return desc + + def _create_model_package_status(sagemaker_client, model_package_name): """Placeholder docstring""" in_progress_statuses = ["InProgress", "Pending"] @@ -7908,7 +8881,9 @@ def _logs_for_job( # noqa: C901 - suppress complexity warning for this method """ sagemaker_client = sagemaker_session.sagemaker_client request_end_time = time.time() + timeout if timeout else None - description = sagemaker_client.describe_training_job(TrainingJobName=job_name) + description = _wait_until( + lambda: sagemaker_client.describe_training_job(TrainingJobName=job_name) + ) print(secondary_training_status_message(description, None), end="") instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init( @@ -8054,8 +9029,19 @@ def _check_job_status(job, desc, status_key_name): elif status != "Completed": reason = desc.get("FailureReason", "(No reason provided)") job_type = status_key_name.replace("JobStatus", " job") - message = "Error for {job_type} {job_name}: {status}. Reason: {reason}".format( - job_type=job_type, job_name=job, status=status, reason=reason + troubleshooting = ( + "https://docs.aws.amazon.com/sagemaker/latest/dg/" + "sagemaker-python-sdk-troubleshooting.html" + ) + message = ( + "Error for {job_type} {job_name}: {status}. Reason: {reason}. 
" + "Check troubleshooting guide for common errors: {troubleshooting}" + ).format( + job_type=job_type, + job_name=job, + status=status, + reason=reason, + troubleshooting=troubleshooting, ) if "CapacityError" in str(reason): raise exceptions.CapacityError( diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 9f4b25f214..586e50da88 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -47,7 +47,7 @@ def __init__( hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, image_uri_region: Optional[str] = None, - **kwargs + **kwargs, ): """Creates a SKLearn Estimator for Scikit-learn environment. @@ -83,8 +83,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + are preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code on @@ -154,7 +154,7 @@ def __init__( source_dir, hyperparameters, image_uri=image_uri, - **dict(kwargs, instance_count=1) + **dict(kwargs, instance_count=1), ) if image_uri is None: @@ -174,7 +174,7 @@ def create_model( entry_point=None, source_dir=None, dependencies=None, - **kwargs + **kwargs, ): """Create a SageMaker ``SKLearnModel`` object that can be deployed to an ``Endpoint``. 
@@ -229,7 +229,7 @@ def create_model(
            sagemaker_session=self.sagemaker_session,
            vpc_config=self.get_vpc_config(vpc_config_override),
            dependencies=(dependencies or self.dependencies),
-            **kwargs
+            **kwargs,
        )

    @classmethod
diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py
index 27833c1d9c..a9b0e2e8f0 100644
--- a/src/sagemaker/sklearn/model.py
+++ b/src/sagemaker/sklearn/model.py
@@ -14,7 +14,7 @@
 from __future__ import absolute_import

 import logging
-from typing import Union, Optional, List, Dict
+from typing import Callable, Union, Optional, List, Dict

 import sagemaker
 from sagemaker import image_uris, ModelMetrics
@@ -23,12 +23,17 @@
 from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
 from sagemaker.metadata_properties import MetadataProperties
 from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
+from sagemaker.model_card import (
+    ModelCard,
+    ModelPackageModelCard,
+)
 from sagemaker.predictor import Predictor
 from sagemaker.serializers import NumpySerializer
 from sagemaker.sklearn import defaults
 from sagemaker.utils import to_string
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.model_life_cycle import ModelLifeCycle

 logger = logging.getLogger("sagemaker")
@@ -87,9 +92,9 @@ def __init__(
        framework_version: Optional[str] = None,
        py_version: str = "py3",
        image_uri: Optional[Union[str, PipelineVariable]] = None,
-        predictor_cls: callable = SKLearnPredictor,
+        predictor_cls: Optional[Callable] = SKLearnPredictor,
        model_server_workers: Optional[Union[int, PipelineVariable]] = None,
-        **kwargs
+        **kwargs,
    ):
        """Initialize an SKLearnModel.
@@ -117,7 +122,7 @@
                If ``framework_version`` or ``py_version`` are ``None``, then
                ``image_uri`` is required. If ``image_uri`` is also ``None``,
                then a ``ValueError`` will be raised.
-            predictor_cls (callable[str, sagemaker.session.Session]): A function
+            predictor_cls (Callable[[str, sagemaker.session.Session], Any]): A function
                to call to create a predictor with an endpoint name and
                SageMaker ``Session``. If specified, ``deploy()`` returns
                the result of invoking this function on the created endpoint name.
@@ -172,6 +177,8 @@ def register(
        data_input_configuration: Optional[Union[str, PipelineVariable]] = None,
        skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
        source_uri: Optional[Union[str, PipelineVariable]] = None,
+        model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
+        model_life_cycle: Optional[ModelLifeCycle] = None,
    ):
        """Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -223,6 +230,9 @@
                validation. Values can be "All" or "None" (default: None).
            source_uri (str or PipelineVariable): The URI of the source for the model package
                (default: None).
+            model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative
+                and quantitative information about a model (default: None).
+            model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None).

        Returns:
            A `sagemaker.model.ModelPackage` instance.
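The new ``model_card`` and ``model_life_cycle`` arguments thread through to ``CreateModelPackage``. A hedged sketch of how a caller might supply them (the content and status values are placeholders, and the constructor keywords are assumptions based on the imports above, not verified against the released API):

.. code:: python

    from sagemaker.model_card import ModelPackageModelCard
    from sagemaker.model_life_cycle import ModelLifeCycle

    # placeholder content; constructor keywords below are assumptions
    model_card = ModelPackageModelCard(
        model_card_content={"model_overview": {"model_description": "demo"}},
        model_card_status="Draft",
    )
    model_life_cycle = ModelLifeCycle(
        stage="Development",
        stage_status="In-Progress",
    )

    # sklearn_model is an existing SKLearnModel; names are illustrative
    # model_package = sklearn_model.register(
    #     content_types=["text/csv"],
    #     response_types=["text/csv"],
    #     model_package_group_name="my-model-group",
    #     model_card=model_card,
    #     model_life_cycle=model_life_cycle,
    # )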
@@ -263,6 +273,8 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + model_life_cycle=model_life_cycle, ) def prepare_container_def( @@ -271,6 +283,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Container definition with framework configuration set in model environment variables. @@ -320,6 +333,7 @@ def prepare_container_def( model_data_uri, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): diff --git a/src/sagemaker/telemetry/__init__.py b/src/sagemaker/telemetry/__init__.py new file mode 100644 index 0000000000..ada3f1f09f --- /dev/null +++ b/src/sagemaker/telemetry/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Placeholder docstring""" +from __future__ import absolute_import + +from .telemetry_logging import _telemetry_emitter # noqa: F401 diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py new file mode 100644 index 0000000000..6766d45b4e --- /dev/null +++ b/src/sagemaker/telemetry/constants.py @@ -0,0 +1,82 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Constants used in SageMaker Python SDK telemetry.""" + +from __future__ import absolute_import +from enum import Enum + +# Default AWS region used by SageMaker +DEFAULT_AWS_REGION = "us-west-2" + + +class Feature(Enum): + """Enumeration of feature names used in telemetry.""" + + SDK_DEFAULTS = 1 + LOCAL_MODE = 2 + REMOTE_FUNCTION = 3 + MODEL_TRAINER = 4 + ESTIMATOR = 5 + HYPERPOD = 6 # Added to support telemetry in sagemaker-hyperpod-cli + + def __str__(self): # pylint: disable=E0307 + """Return the feature name.""" + return self.name + + +class Status(Enum): + """Enumeration of status values used in telemetry.""" + + SUCCESS = 1 + FAILURE = 0 + + def __str__(self): # pylint: disable=E0307 + """Return the status name.""" + return self.name + + +class Region(str, Enum): + """Telemetry: List of all supported AWS regions.""" + + # Classic + US_EAST_1 = "us-east-1" # IAD + US_EAST_2 = "us-east-2" # CMH + US_WEST_1 = "us-west-1" # SFO + US_WEST_2 = "us-west-2" # PDX + AP_NORTHEAST_1 = "ap-northeast-1" # NRT + AP_NORTHEAST_2 = "ap-northeast-2" # ICN + AP_NORTHEAST_3 = "ap-northeast-3" # KIX + AP_SOUTH_1 = "ap-south-1" # BOM + AP_SOUTHEAST_1 = "ap-southeast-1" # SIN + AP_SOUTHEAST_2 = "ap-southeast-2" # SYD + CA_CENTRAL_1 = "ca-central-1" # YUL + EU_CENTRAL_1 = "eu-central-1" # FRA + EU_NORTH_1 = "eu-north-1" # ARN + EU_WEST_1 = "eu-west-1" # DUB + EU_WEST_2 = "eu-west-2" # LHR + EU_WEST_3 = "eu-west-3" # CDG + SA_EAST_1 = "sa-east-1" # GRU + # Opt-in + AP_EAST_1 = "ap-east-1" # HKG + AP_SOUTHEAST_3 = "ap-southeast-3" # CGK + AF_SOUTH_1 = "af-south-1" # CPT + EU_SOUTH_1 = "eu-south-1" # MXP + ME_SOUTH_1 = "me-south-1" # BAH + MX_CENTRAL_1 = "mx-central-1" # QRO + AP_SOUTHEAST_7 = "ap-southeast-7" # BKK + AP_SOUTH_2 = "ap-south-2" # HYD + AP_SOUTHEAST_4 = "ap-southeast-4" # MEL + EU_CENTRAL_2 = "eu-central-2" # ZRH + EU_SOUTH_2 = "eu-south-2" # ZAZ + IL_CENTRAL_1 = "il-central-1" # TLV + ME_CENTRAL_1 = "me-central-1" # DXB diff --git a/src/sagemaker/telemetry/telemetry_logging.py b/src/sagemaker/telemetry/telemetry_logging.py new file mode 100644 index 0000000000..990e12124f --- /dev/null +++ b/src/sagemaker/telemetry/telemetry_logging.py @@ -0,0 +1,285 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Telemetry module for SageMaker Python SDK to collect usage data and metrics.""" +from __future__ import absolute_import +import logging +import platform +import sys +from time import perf_counter +from typing import List +import functools +import requests + +import boto3 +from sagemaker.session import Session +from sagemaker.utils import resolve_value_from_config +from sagemaker.config.config_schema import TELEMETRY_OPT_OUT_PATH +from sagemaker.telemetry.constants import ( + Feature, + Status, + Region, + DEFAULT_AWS_REGION, +) +from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file + +logger = logging.getLogger(__name__) + +OS_NAME = platform.system() or "UnresolvedOS" +OS_VERSION = platform.release() or "UnresolvedOSVersion" +OS_NAME_VERSION = "{}/{}".format(OS_NAME, OS_VERSION) +PYTHON_VERSION = "{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro +) + +TELEMETRY_OPT_OUT_MESSAGING = ( + "SageMaker Python SDK will collect telemetry to help us better understand our user's needs, " + "diagnose issues, and deliver additional features.\n" + "To opt out of telemetry, please disable via TelemetryOptOut parameter in SDK defaults config. " + "For more information, refer to https://sagemaker.readthedocs.io/en/stable/overview.html" + "#configuring-and-using-defaults-with-the-sagemaker-python-sdk." +) + +FEATURE_TO_CODE = { + str(Feature.SDK_DEFAULTS): 1, + str(Feature.LOCAL_MODE): 2, + str(Feature.REMOTE_FUNCTION): 3, + str(Feature.MODEL_TRAINER): 4, + str(Feature.ESTIMATOR): 5, + str(Feature.HYPERPOD): 6, # Added to support telemetry in sagemaker-hyperpod-cli +} + +STATUS_TO_CODE = { + str(Status.SUCCESS): 1, + str(Status.FAILURE): 0, +} + + +def _telemetry_emitter(feature: str, func_name: str): + """Telemetry Emitter + + Decorator to emit telemetry logs for SageMaker Python SDK functions. This class needs + sagemaker_session object as a member. Default session object is a pysdk v2 Session object + in this repo. When collecting telemetry for classes using sagemaker-core Session object, + we should be aware of its differences, such as sagemaker_session.sagemaker_config does not + exist in new Session class. 
+ """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + sagemaker_session = None + if len(args) > 0 and hasattr(args[0], "sagemaker_session"): + # Get the sagemaker_session from the instance method args + sagemaker_session = args[0].sagemaker_session + elif feature == Feature.REMOTE_FUNCTION: + # Get the sagemaker_session from the function keyword arguments for remote function + sagemaker_session = kwargs.get( + "sagemaker_session", _get_default_sagemaker_session() + ) + + if sagemaker_session: + logger.debug("sagemaker_session found, preparing to emit telemetry...") + logger.info(TELEMETRY_OPT_OUT_MESSAGING) + response = None + caught_ex = None + studio_app_type = process_studio_metadata_file() + + # Check if telemetry is opted out + telemetry_opt_out_flag = resolve_value_from_config( + direct_input=None, + config_path=TELEMETRY_OPT_OUT_PATH, + default_value=False, + sagemaker_session=sagemaker_session, + ) + logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag) + + # Construct the feature list to track feature combinations + feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]] + + if ( + hasattr(sagemaker_session, "sagemaker_config") + and sagemaker_session.sagemaker_config + and feature != Feature.SDK_DEFAULTS + ): + feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)]) + + if ( + hasattr(sagemaker_session, "local_mode") + and sagemaker_session.local_mode + and feature != Feature.LOCAL_MODE + ): + feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)]) + + # Construct the extra info to track platform and environment usage metadata + extra = ( + f"{func_name}" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-env={PYTHON_VERSION}" + f"&x-sys={OS_NAME_VERSION}" + f"&x-platform={studio_app_type}" + ) + + # Add endpoint ARN to the extra info if available + if hasattr(sagemaker_session, "endpoint_arn") and sagemaker_session.endpoint_arn: + extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}" + + start_timer = perf_counter() + try: + # Call the original function + response = func(*args, **kwargs) + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + extra += f"&x-latency={round(elapsed, 2)}" + if not telemetry_opt_out_flag: + _send_telemetry_request( + STATUS_TO_CODE[str(Status.SUCCESS)], + feature_list, + sagemaker_session, + None, + None, + extra, + ) + except Exception as e: # pylint: disable=W0703 + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + extra += f"&x-latency={round(elapsed, 2)}" + if not telemetry_opt_out_flag: + _send_telemetry_request( + STATUS_TO_CODE[str(Status.FAILURE)], + feature_list, + sagemaker_session, + str(e), + e.__class__.__name__, + extra, + ) + caught_ex = e + finally: + if caught_ex: + raise caught_ex + return response # pylint: disable=W0150 + else: + logger.debug( + "Unable to send telemetry for function %s. " + "sagemaker_session is not provided or not valid.", + func_name, + ) + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def _send_telemetry_request( + status: int, + feature_list: List[int], + session: Session, + failure_reason: str = None, + failure_type: str = None, + extra_info: str = None, +) -> None: + """Make GET request to an empty object in S3 bucket""" + try: + accountId = _get_accountId(session) if session else "NotAvailable" + region = _get_region_or_default(session) + + try: + Region(region) # Validate the region + except ValueError: + logger.warning( + "Region not found in supported regions. 
Telemetry request will not be emitted."
+            )
+            return
+
+        url = _construct_url(
+            accountId,
+            region,
+            str(status),
+            str(
+                ",".join(map(str, feature_list))
+            ),  # Remove brackets and quotes to cut down on length
+            failure_reason,
+            failure_type,
+            extra_info,
+        )
+        # Send the telemetry request
+        logger.debug("Sending telemetry request to [%s]", url)
+        _requests_helper(url, 2)
+        logger.debug("SageMaker Python SDK telemetry successfully emitted.")
+    except Exception:  # pylint: disable=W0703
+        logger.debug("SageMaker Python SDK telemetry not emitted!")
+
+
+def _construct_url(
+    accountId: str,
+    region: str,
+    status: str,
+    feature: str,
+    failure_reason: str,
+    failure_type: str,
+    extra_info: str,
+) -> str:
+    """Construct the URL for the telemetry request"""
+
+    base_url = (
+        f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?"
+        f"x-accountId={accountId}"
+        f"&x-status={status}"
+        f"&x-feature={feature}"
+    )
+    logger.debug("Failure reason: %s", failure_reason)
+    if failure_reason:
+        base_url += f"&x-failureReason={failure_reason}"
+        base_url += f"&x-failureType={failure_type}"
+    if extra_info:
+        base_url += f"&x-extra={extra_info}"
+    return base_url
+
+
+def _requests_helper(url, timeout):
+    """Make a GET request to the given URL"""
+
+    response = None
+    try:
+        # Pass the timeout as a keyword argument; a positional second argument
+        # would be interpreted by requests.get as query params, not a timeout.
+        response = requests.get(url, timeout=timeout)
+    except requests.exceptions.RequestException as e:
+        logger.exception("Request exception: %s", str(e))
+    return response
+
+
+def _get_accountId(session):
+    """Return the account ID from the boto session"""
+
+    try:
+        sts = session.boto_session.client("sts")
+        return sts.get_caller_identity()["Account"]
+    except Exception:  # pylint: disable=W0703
+        return None
+
+
+def _get_region_or_default(session):
+    """Return the region name from the boto session or default to us-west-2"""
+
+    try:
+        return session.boto_session.region_name
+    except Exception:  # pylint: disable=W0703
+        return DEFAULT_AWS_REGION
+
+
+def _get_default_sagemaker_session():
+    """Return the default sagemaker session"""
+
+    boto_session = boto3.Session(region_name=DEFAULT_AWS_REGION)
+    sagemaker_session = Session(boto_session=boto_session)
+
+    return sagemaker_session
diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py
index 77f162207c..b384cbbbb5 100644
--- a/src/sagemaker/tensorflow/model.py
+++ b/src/sagemaker/tensorflow/model.py
@@ -14,7 +14,7 @@
 from __future__ import absolute_import

 import logging
-from typing import Union, Optional, List, Dict
+from typing import Callable, Union, Optional, List, Dict

 import sagemaker
 from sagemaker import image_uris, s3, ModelMetrics
@@ -22,12 +22,17 @@
 from sagemaker.deprecations import removed_kwargs
 from sagemaker.drift_check_baselines import DriftCheckBaselines
 from sagemaker.metadata_properties import MetadataProperties
+from sagemaker.model_card import (
+    ModelCard,
+    ModelPackageModelCard,
+)
 from sagemaker.predictor import Predictor
 from sagemaker.serializers import JSONSerializer
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.pipeline_context import PipelineSession
 from sagemaker.utils import format_tags
+from sagemaker.model_life_cycle import ModelLifeCycle

 logger = logging.getLogger(__name__)
@@ -57,9 +62,9 @@ def __init__(
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, the estimator creates one
                using the default AWS configuration chain.
-            serializer (callable): Optional.
Default serializes input data to + serializer (Callable): Optional. Default serializes input data to json. Handles dicts, lists, and numpy arrays. - deserializer (callable): Optional. Default parses the response using + deserializer (Callable): Optional. Default parses the response using ``json.load(...)``. model_name (str): Optional. The name of the SavedModel model that should handle the request. If not specified, the endpoint's @@ -141,7 +146,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, container_log_level: Optional[int] = None, - predictor_cls: callable = TensorFlowPredictor, + predictor_cls: Optional[Callable] = TensorFlowPredictor, **kwargs, ): """Initialize a Model. @@ -169,7 +174,7 @@ def __init__( container_log_level (int): Log level to use within the container (default: logging.ERROR). Valid values are defined in the Python logging module. - predictor_cls (callable[str, sagemaker.session.Session]): A function + predictor_cls (Callable[[str, sagemaker.session.Session], Any]): A function to call to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. @@ -195,7 +200,7 @@ def __init__( # patch versions, but end up hosting the model of same TF version. For eg., the upstream # TFS-2.12.0 release was a bad release and hence a new TFS-2.12.1 release was made to host # models from TF-2.12.0. - training_inference_version_mismatch_dict = {"2.12.0": "2.12.1"} + training_inference_version_mismatch_dict = {"2.12.0": "2.12.1", "2.16.2": "2.16.1"} self.inference_framework_version = training_inference_version_mismatch_dict.get( framework_version, framework_version ) @@ -234,6 +239,8 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -285,6 +292,9 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative and + quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
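A minimal sketch (not part of this patch) of how the training/inference version-mismatch table above behaves; the helper name resolve_inference_version is hypothetical:

# Training versions with a known-bad serving release are remapped to a
# patched TFS release; everything else hosts on its own version.
mismatch = {"2.12.0": "2.12.1", "2.16.2": "2.16.1"}

def resolve_inference_version(training_version: str) -> str:
    # Fall back to the training version when no remap entry exists.
    return mismatch.get(training_version, training_version)

assert resolve_inference_version("2.16.2") == "2.16.1"
assert resolve_inference_version("2.15.0") == "2.15.0"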
@@ -325,6 +335,8 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + model_life_cycle=model_life_cycle, ) def deploy( @@ -346,6 +358,7 @@ def deploy( container_startup_health_check_timeout=None, inference_recommendation_id=None, explainer_config=None, + update_endpoint: Optional[bool] = False, **kwargs, ): """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``.""" @@ -371,6 +384,7 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, inference_recommendation_id=inference_recommendation_id, explainer_config=explainer_config, + update_endpoint=update_endpoint, **kwargs, ) @@ -389,6 +403,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Prepare the container definition. @@ -465,6 +480,7 @@ def prepare_container_def( model_data, env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def _get_container_env(self): diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 4b0f38f36f..d9b052770b 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -18,21 +18,20 @@ import inspect import json import logging - from enum import Enum -from typing import Union, Dict, Optional, List, Set +from typing import Dict, List, Optional, Set, Union import sagemaker from sagemaker.amazon.amazon_estimator import ( - RecordSet, AmazonAlgorithmEstimatorBase, FileSystemRecordSet, + RecordSet, ) from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.analytics import HyperparameterTuningJobAnalytics from sagemaker.deprecations import removed_function -from sagemaker.estimator import Framework, EstimatorBase -from sagemaker.inputs import TrainingInput, FileSystemInput +from sagemaker.estimator import EstimatorBase, Framework +from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.job import _Job from sagemaker.jumpstart.utils import ( add_jumpstart_uri_tags, @@ -44,18 +43,17 @@ IntegerParameter, ParameterRange, ) -from sagemaker.workflow.entities import PipelineVariable -from sagemaker.workflow.pipeline_context import runnable_by_pipeline - from sagemaker.session import Session from sagemaker.utils import ( + Tags, base_from_name, base_name_from_image, + format_tags, name_from_base, to_string, - format_tags, - Tags, ) +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.pipeline_context import runnable_by_pipeline AMAZON_ESTIMATOR_MODULE = "sagemaker" AMAZON_ESTIMATOR_CLS_NAMES = { @@ -133,15 +131,12 @@ def __init__( if warm_start_type not in list(WarmStartTypes): raise ValueError( - "Invalid type: {}, valid warm start types are: {}".format( - warm_start_type, list(WarmStartTypes) - ) + f"Invalid type: {warm_start_type}, " + f"valid warm start types are: {list(WarmStartTypes)}" ) if not parents: - raise ValueError( - "Invalid parents: {}, parents should not be None/empty".format(parents) - ) + raise ValueError(f"Invalid parents: {parents}, parents should not be None/empty") self.type = warm_start_type self.parents = set(parents) @@ -1455,9 +1450,7 @@ def _get_best_training_job(self): return tuning_job_describe_result["BestTrainingJob"] except KeyError: raise Exception( - "Best training job not available for tuning job: {}".format( - self.latest_tuning_job.name - ) + f"Best training job not available for tuning job: {self.latest_tuning_job.name}" ) def 
_ensure_last_tuning_job(self): @@ -1920,8 +1913,11 @@ def create( :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches. If not specified, a default job name is generated, based on the training image name and current timestamp. - strategy (str): Strategy to be used for hyperparameter estimations - (default: 'Bayesian'). + strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations. + More information about different strategies: + https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html. + Available options are: 'Bayesian', 'Random', 'Hyperband', + 'Grid' (default: 'Bayesian') strategy_config (dict): The configuration for a training job launched by a hyperparameter tuning job. completion_criteria_config (dict): The configuration for tuning job completion criteria. @@ -2080,21 +2076,19 @@ def _validate_dict_argument(cls, name, value, allowed_keys, require_same_keys=Fa return if not isinstance(value, dict): - raise ValueError( - "Argument '{}' must be a dictionary using {} as keys".format(name, allowed_keys) - ) + raise ValueError(f"Argument '{name}' must be a dictionary using {allowed_keys} as keys") value_keys = sorted(value.keys()) if require_same_keys: if value_keys != allowed_keys: raise ValueError( - "The keys of argument '{}' must be the same as {}".format(name, allowed_keys) + f"The keys of argument '{name}' must be the same as {allowed_keys}" ) else: if not set(value_keys).issubset(set(allowed_keys)): raise ValueError( - "The keys of argument '{}' must be a subset of {}".format(name, allowed_keys) + f"The keys of argument '{name}' must be a subset of {allowed_keys}" ) def _add_estimator( @@ -2123,6 +2117,72 @@ def _add_estimator( delete_endpoint = removed_function("delete_endpoint") + @staticmethod + def visualize_jobs( + tuning_jobs: Union[ + str, + "sagemaker.tuner.HyperparameterTuner", + List[Union[str, "sagemaker.tuner.HyperparameterTuner"]], + ], + return_dfs: bool = False, + job_metrics: Optional[List[str]] = None, + trials_only: bool = False, + advanced: bool = False, + ): + """Create interactive visualization via altair charts using the sagemaker.amtviz package. + + Args: + tuning_jobs (str or sagemaker.tuner.HyperparameterTuner or list[str, sagemaker.tuner.HyperparameterTuner]): + One or more tuning jobs to create + visualization for. + return_dfs: (bool): Option to return trials and full dataframe. + job_metrics: (list[str]): Metrics to be used in charts. + trials_only: (bool): Whether to show trials only or full dataframe. + advanced: (bool): Show a cumulative step line in the progress over time chart. + Returns: + A collection of charts (altair.VConcatChart); or charts, trials_df (pandas.DataFrame), + full_df (pandas.DataFrame) if ``return_dfs=True``. + """ + try: + # Check if altair is installed + importlib.import_module("altair") + + except ImportError: + print("Altair is not installed. 
Install Altair to use the visualization feature:") + print(" pip install altair") + print("After installing Altair, use the methods visualize_jobs or visualize_job.") + return None + + # If altair is installed, proceed with visualization + from sagemaker.amtviz import visualize_tuning_job + + return visualize_tuning_job( + tuning_jobs, + return_dfs=return_dfs, + job_metrics=job_metrics, + trials_only=trials_only, + advanced=advanced, + ) + + def visualize_job( + self, + return_dfs: bool = False, + job_metrics: Optional[List[str]] = None, + trials_only: bool = False, + advanced: bool = False, + ): + """Convenience method on instance level for visualize_jobs(). + + See static method visualize_jobs(). + """ + return HyperparameterTuner.visualize_jobs( + self, + return_dfs=return_dfs, + job_metrics=job_metrics, + trials_only=trials_only, + advanced=advanced, + ) + class _TuningJob(_Job): """Placeholder docstring""" diff --git a/src/sagemaker/user_agent.py b/src/sagemaker/user_agent.py index 8af89696c2..c1b2bcac07 100644 --- a/src/sagemaker/user_agent.py +++ b/src/sagemaker/user_agent.py @@ -13,8 +13,6 @@ """Placeholder docstring""" from __future__ import absolute_import -import platform -import sys import json import os @@ -28,12 +26,6 @@ STUDIO_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json" SDK_VERSION = importlib_metadata.version("sagemaker") -OS_NAME = platform.system() or "UnresolvedOS" -OS_VERSION = platform.release() or "UnresolvedOSVersion" -OS_NAME_VERSION = "{}/{}".format(OS_NAME, OS_VERSION) -PYTHON_VERSION = "Python/{}.{}.{}".format( - sys.version_info.major, sys.version_info.minor, sys.version_info.micro -) def process_notebook_metadata_file(): @@ -63,45 +55,24 @@ def process_studio_metadata_file(): return None -def determine_prefix(user_agent=""): - """Determines the prefix for the user agent string. +def get_user_agent_extra_suffix(): + """Get the user agent extra suffix string specific to SageMaker Python SDK - Args: - user_agent (str): The user agent string to prepend the prefix to. + Adheres to the new boto-recommended User-Agent 2.0 header format Returns: - str: The user agent string with the prefix prepended. + str: The user agent extra suffix string to be appended """ - prefix = "{}/{}".format(SDK_PREFIX, SDK_VERSION) - - if PYTHON_VERSION not in user_agent: - prefix = "{} {}".format(prefix, PYTHON_VERSION) - - if OS_NAME_VERSION not in user_agent: - prefix = "{} {}".format(prefix, OS_NAME_VERSION) + suffix = "lib/{}#{}".format(SDK_PREFIX, SDK_VERSION) # Get the notebook instance type and prepend it to the user agent string if exists notebook_instance_type = process_notebook_metadata_file() if notebook_instance_type: - prefix = "{} {}/{}".format(prefix, NOTEBOOK_PREFIX, notebook_instance_type) + suffix = "{} md/{}#{}".format(suffix, NOTEBOOK_PREFIX, notebook_instance_type) # Get the studio app type and prepend it to the user agent string if exists studio_app_type = process_studio_metadata_file() if studio_app_type: - prefix = "{} {}/{}".format(prefix, STUDIO_PREFIX, studio_app_type) - - return prefix - - -def prepend_user_agent(client): - """Prepends the user agent string with the SageMaker Python SDK version. - - Args: - client (botocore.client.BaseClient): The client to prepend the user agent string for. 
- """ - prefix = determine_prefix(client._client_config.user_agent) + suffix = "{} md/{}#{}".format(suffix, STUDIO_PREFIX, studio_app_type) - if client._client_config.user_agent is None: - client._client_config.user_agent = prefix - else: - client._client_config.user_agent = "{} {}".format(prefix, client._client_config.user_agent) + return suffix diff --git a/src/sagemaker/utilities/search_expression.py b/src/sagemaker/utilities/search_expression.py index 5b2aaf3226..d59ee76277 100644 --- a/src/sagemaker/utilities/search_expression.py +++ b/src/sagemaker/utilities/search_expression.py @@ -108,7 +108,7 @@ def __init__( nested_filters=None, sub_expressions=None, boolean_operator=BooleanOperator.AND, - **kwargs + **kwargs, ): """Construct a Search Expression object diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 0436c0afea..2a31dfab04 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -13,10 +13,12 @@ """Placeholder docstring""" from __future__ import absolute_import +import abc import contextlib import copy import errno import inspect +import json import logging import os import random @@ -25,29 +27,41 @@ import tarfile import tempfile import time -from typing import Union, Any, List, Optional, Dict -import json -import abc import uuid from datetime import datetime -from os.path import abspath, realpath, dirname, normpath, join as joinpath - +from functools import lru_cache from importlib import import_module +from os.path import abspath, dirname +from os.path import join as joinpath +from os.path import normpath, realpath +from typing import Any, Dict, List, Optional, Union + +import boto3 import botocore from botocore.utils import merge_dicts +from six import viewitems from six.moves.urllib import parse -import pandas as pd from sagemaker import deprecations from sagemaker.config import validate_sagemaker_config from sagemaker.config.config_utils import ( - _log_sagemaker_config_single_substitution, _log_sagemaker_config_merge, + _log_sagemaker_config_single_substitution, ) +from sagemaker.enums import RoutingStrategy from sagemaker.session_settings import SessionSettings -from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string +from sagemaker.workflow import is_pipeline_parameter_string, is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +ALTERNATE_DOMAINS = { + "cn-north-1": "amazonaws.com.cn", + "cn-northwest-1": "amazonaws.com.cn", + "us-iso-east-1": "c2s.ic.gov", + "us-isob-east-1": "sc2s.sgov.gov", + "us-isof-south-1": "csp.hci.ic.gov", + "us-isof-east-1": "csp.hci.ic.gov", +} + ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" MODEL_PACKAGE_ARN_PATTERN = ( r"arn:aws([a-z\-]*)?:sagemaker:([a-z0-9\-]*):([0-9]{12}):model-package/(.*)" @@ -384,8 +398,7 @@ def download_folder(bucket_name, prefix, target, sagemaker_session): sagemaker_session (sagemaker.session.Session): a sagemaker session to interact with S3. """ - boto_session = sagemaker_session.boto_session - s3 = boto_session.resource("s3", region_name=boto_session.region_name) + s3 = sagemaker_session.s3_resource prefix = prefix.lstrip("/") @@ -612,7 +625,24 @@ def _create_or_update_code_dir( if os.path.exists(os.path.join(code_dir, inference_script)): pass else: - raise + raise FileNotFoundError( + f"Could not find '{inference_script}'. Common solutions:\n" + "1. Make sure inference.py exists in the code/ directory\n" + "2. 
Package your model correctly:\n" + " - ✅ DO: Navigate to the directory containing model files and run:\n" + " cd /path/to/model_files\n" + " tar czvf ../model.tar.gz *\n" + " - ❌ DON'T: Create from parent directory:\n" + " tar czvf model.tar.gz model/\n" + "\nExpected structure in model.tar.gz:\n" + " ├── model.pth (or your model file)\n" + " └── code/\n" + " ├── inference.py\n" + " └── requirements.txt\n" + "\nFor more details, see the documentation:\n" + + "https://sagemaker.readthedocs.io/en/stable/" + + "frameworks/pytorch/using_pytorch.html#bring-your-own-model" + ) for dependency in dependencies: lib_dir = os.path.join(code_dir, "lib") @@ -713,7 +743,7 @@ def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code """Retry with backoff until maximum attempts are reached Args: - callable_func (callable): The callable function to retry. + callable_func (Callable): The callable function to retry. num_attempts (int): The maximum number of attempts to retry.(Default: 8) botocore_client_error_code (str): The specific Botocore ClientError exception error code on which to retry on. @@ -1147,7 +1177,7 @@ def get_sagemaker_config_value(sagemaker_session, key, sagemaker_config: dict = Returns: object: The corresponding default value in the configuration file. """ - if sagemaker_session: + if sagemaker_session and hasattr(sagemaker_session, "sagemaker_config"): config_to_check = sagemaker_session.sagemaker_config else: config_to_check = sagemaker_config @@ -1450,10 +1480,15 @@ def volume_size_supported(instance_type: str) -> bool: if len(parts) != 2: raise ValueError(f"Failed to parse instance type '{instance_type}'") - # Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) + g5 - # does not support attaching an EBS volume. + # Any instance type with a "d" in the instance family (i.e. c5d, p4d, etc) + # + g5, g6, p5, or trn1 does not support attaching an EBS volume. family = parts[0] + + unsupported_families = ["g5", "g6", "p5", "trn1"] + + return "d" not in family and not any( + family.startswith(prefix) for prefix in unsupported_families + ) except Exception as e: raise ValueError(f"Failed to parse instance type '{instance_type}': {str(e)}") @@ -1467,6 +1502,24 @@ def instance_supports_kms(instance_type: str) -> bool: return volume_size_supported(instance_type) +def get_training_job_name_from_training_job_arn(training_job_arn: str) -> str: + """Extract the training job name from a training job ARN. + + Args: + training_job_arn: The training job ARN. + + Returns: The training job name, or None if the ARN does not match. + + """ + if training_job_arn is None: + return None + pattern = "arn:aws[a-z-]*:sagemaker:[a-z0-9-]*:[0-9]{12}:training-job/(.+)" + match = re.match(pattern, training_job_arn) + if match: + return match.group(1) + return None + + def get_instance_type_family(instance_type: str) -> str: """Return the family of the instance type. @@ -1602,44 +1655,78 @@ def can_model_package_source_uri_autopopulate(source_uri: str): ) -def flatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: - """Flatten a nested dictionary. +def flatten_dict( + d: Dict[str, Any], + max_flatten_depth=None, +) -> Dict[str, Any]: + """Flatten a dictionary object. - Args: - source_dict (dict): The dictionary to be flattened. - sep (str): The separator to be used in the flattened dictionary. - Returns: - transformed_dict: The flattened dictionary. + d (Dict[str, Any]): + The dict that will be flattened. 
+ max_flatten_depth (Optional[int]): + Maximum depth to flatten. """ - flat_dict_list = pd.json_normalize(source_dict, sep=sep).to_dict(orient="records") - if flat_dict_list: - return flat_dict_list[0] - return {} + def tuple_reducer(k1, k2): + if k1 is None: + return (k2,) + return k1 + (k2,) -def unflatten_dict(source_dict: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: - """Unflatten a flattened dictionary back into a nested dictionary. + # check max_flatten_depth + if max_flatten_depth is not None and max_flatten_depth < 1: + raise ValueError("max_flatten_depth should not be less than 1.") - Args: - source_dict (dict): The input flattened dictionary. - sep (str): The separator used in the flattened keys. + reducer = tuple_reducer - Returns: - transformed_dict: The reconstructed nested dictionary. + flat_dict = {} + + def _flatten(_d, depth, parent=None): + key_value_iterable = viewitems(_d) + has_item = False + for key, value in key_value_iterable: + has_item = True + flat_key = reducer(parent, key) + if isinstance(value, dict) and (max_flatten_depth is None or depth < max_flatten_depth): + has_child = _flatten(value, depth=depth + 1, parent=flat_key) + if has_child: + continue + + if flat_key in flat_dict: + raise ValueError("duplicated key '{}'".format(flat_key)) + flat_dict[flat_key] = value + + return has_item + + _flatten(d, depth=1) + return flat_dict + + +def nested_set_dict(d: Dict[str, Any], keys: List[str], value: Any) -> None: + """Set a value to a sequence of nested keys.""" + + key = keys[0] + + if len(keys) == 1: + d[key] = value + return + + d = d.setdefault(key, {}) + nested_set_dict(d, keys[1:], value) + + +def unflatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: + """Unflatten dict-like object. + + d (Dict[str, Any]): + The dict that will be unflattened. """ - if not source_dict: - return {} - result = {} - for key, value in source_dict.items(): - keys = key.split(sep) - current = result - for k in keys[:-1]: - if k not in current: - current[k] = {} - current = current[k] if current[k] is not None else current - current[keys[-1]] = value - return result + unflattened_dict = {} + for flat_key, value in viewitems(d): + key_tuple = flat_key + nested_set_dict(unflattened_dict, key_tuple, value) + + return unflattened_dict def deep_override_dict( @@ -1650,8 +1737,224 @@ def deep_override_dict( skip_keys = [] flattened_dict1 = flatten_dict(dict1) + flattened_dict1 = {key: value for key, value in flattened_dict1.items() if value is not None} flattened_dict2 = flatten_dict( {key: value for key, value in dict2.items() if key not in skip_keys} ) flattened_dict1.update(flattened_dict2) return unflatten_dict(flattened_dict1) if flattened_dict1 else {} + + +def _resolve_routing_config(routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """Resolve Routing Config + + Args: + routing_config (Optional[Dict[str, Any]]): The routing config. + + Returns: + Optional[Dict[str, Any]]: The resolved routing config. + + Raises: + ValueError: If the RoutingStrategy is invalid. 
+ """ + + if routing_config: + routing_strategy = routing_config.get("RoutingStrategy", None) + if routing_strategy: + if isinstance(routing_strategy, RoutingStrategy): + return {"RoutingStrategy": routing_strategy.name} + if isinstance(routing_strategy, str) and ( + routing_strategy.upper() == RoutingStrategy.RANDOM.name + or routing_strategy.upper() == RoutingStrategy.LEAST_OUTSTANDING_REQUESTS.name + ): + return {"RoutingStrategy": routing_strategy.upper()} + raise ValueError( + "RoutingStrategy must be either RoutingStrategy.RANDOM " + "or RoutingStrategy.LEAST_OUTSTANDING_REQUESTS" + ) + return None + + +@lru_cache +def get_instance_rate_per_hour( + instance_type: str, + region: str, +) -> Optional[Dict[str, str]]: + """Gets instance rate per hour for the given instance type. + + Args: + instance_type (str): The instance type. + region (str): The region. + Returns: + Optional[Dict[str, str]]: Instance rate per hour. + Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.125'}. + + Raises: + Exception: An exception is raised if + the IAM role is not authorized to perform pricing:GetProducts. + or unexpected event happened. + """ + region_name = "us-east-1" + if region.startswith("eu") or region.startswith("af"): + region_name = "eu-central-1" + elif region.startswith("ap") or region.startswith("cn"): + region_name = "ap-south-1" + + pricing_client: boto3.client = boto3.client("pricing", region_name=region_name) + res = pricing_client.get_products( + ServiceCode="AmazonSageMaker", + Filters=[ + {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type}, + {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"}, + {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region}, + ], + ) + + price_list = res.get("PriceList", []) + if len(price_list) > 0: + price_data = price_list[0] + if isinstance(price_data, str): + price_data = json.loads(price_data) + + instance_rate_per_hour = extract_instance_rate_per_hour(price_data) + if instance_rate_per_hour is not None: + return instance_rate_per_hour + raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.") + + +def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[str, str]]: + """Extract instance rate per hour for the given Price JSON data. + + Args: + price_data (Dict[str, Any]): The Price JSON data. + Returns: + Optional[Dict[str, str], None]: Instance rate per hour. + """ + + if price_data is not None: + price_dimensions = price_data.get("terms", {}).get("OnDemand", {}).values() + for dimension in price_dimensions: + for price in dimension.get("priceDimensions", {}).values(): + for currency in price.get("pricePerUnit", {}).keys(): + value = price.get("pricePerUnit", {}).get(currency) + if value is not None: + value = str(round(float(value), 3)) + return { + "unit": f"{currency}/Hr", + "value": value, + "name": "On-demand Instance Rate", + } + return None + + +def camel_case_to_pascal_case(data: Dict[str, Any]) -> Dict[str, Any]: + """Iteratively updates a dictionary to convert all keys from snake_case to PascalCase. + + Args: + data (dict): The dictionary to be updated. + + Returns: + dict: The updated dictionary with keys in PascalCase. 
+ """ + result = {} + + def convert_key(key): + """Converts a snake_case key to PascalCase.""" + return "".join(part.capitalize() for part in key.split("_")) + + def convert_value(value): + """Recursively processes the value of a key-value pair.""" + if isinstance(value, dict): + return camel_case_to_pascal_case(value) + if isinstance(value, list): + return [convert_value(item) for item in value] + + return value + + for key, value in data.items(): + result[convert_key(key)] = convert_value(value) + + return result + + +def tag_exists(tag: TagsDict, curr_tags: Optional[Tags]) -> bool: + """Returns True if ``tag`` already exists. + + Args: + tag (TagsDict): The tag dictionary. + curr_tags (Optional[Tags]): The current tags. + + Returns: + bool: True if the tag exists. + """ + if curr_tags is None: + return False + + for curr_tag in curr_tags: + if tag["Key"] == curr_tag["Key"]: + return True + + return False + + +def _validate_new_tags(new_tags: Optional[Tags], curr_tags: Optional[Tags]) -> Optional[Tags]: + """Validates new tags against existing tags. + + Args: + new_tags (Optional[Tags]): The new tags. + curr_tags (Optional[Tags]): The current tags. + + Returns: + Optional[Tags]: The updated tags. + """ + if curr_tags is None: + return new_tags + + if curr_tags and isinstance(curr_tags, dict): + curr_tags = [curr_tags] + + if isinstance(new_tags, dict): + if not tag_exists(new_tags, curr_tags): + curr_tags.append(new_tags) + elif isinstance(new_tags, list): + for new_tag in new_tags: + if not tag_exists(new_tag, curr_tags): + curr_tags.append(new_tag) + + return curr_tags + + +def remove_tag_with_key(key: str, tags: Optional[Tags]) -> Optional[Tags]: + """Remove a tag with the given key from the list of tags. + + Args: + key (str): The key of the tag to remove. + tags (Optional[Tags]): The current list of tags. + + Returns: + Optional[Tags]: The updated list of tags with the tag removed. + """ + if tags is None: + return tags + if isinstance(tags, dict): + tags = [tags] + + updated_tags = [] + for tag in tags: + if tag["Key"] != key: + updated_tags.append(tag) + + if not updated_tags: + return None + if len(updated_tags) == 1: + return updated_tags[0] + return updated_tags + + +def get_domain_for_region(region: str) -> str: + """Returns the domain for the given region. + + Args: + region (str): AWS region name. + """ + return ALTERNATE_DOMAINS.get(region, "amazonaws.com") diff --git a/src/sagemaker/workflow/_repack_model.py b/src/sagemaker/workflow/_repack_model.py index 84b3a426f6..2a9129da2f 100644 --- a/src/sagemaker/workflow/_repack_model.py +++ b/src/sagemaker/workflow/_repack_model.py @@ -27,13 +27,6 @@ # is unpacked for inference, the custom entry point will be used. # Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html -# distutils.dir_util.copy_tree works way better than the half-baked -# shutil.copytree which bombs on previously existing target dirs... -# alas ... https://bugs.python.org/issue10948 -# we'll go ahead and use the copy_tree function anyways because this -# repacking is some short-lived hackery, right?? 
-from distutils.dir_util import copy_tree - from os.path import abspath, realpath, dirname, normpath, join as joinpath logger = logging.getLogger(__name__) @@ -188,7 +181,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None): # copy the "src" dir, which includes the previous training job's model and the # custom inference script, to the output of this training job - copy_tree(src_dir, "/opt/ml/model") + shutil.copytree(src_dir, "/opt/ml/model", dirs_exist_ok=True) if __name__ == "__main__": # pragma: no cover diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 841cd68083..36c393969a 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) -FRAMEWORK_VERSION = "0.23-1" +FRAMEWORK_VERSION = "1.2-1" INSTANCE_TYPE = "ml.m5.large" REPACK_SCRIPT = "_repack_model.py" REPACK_SCRIPT_LAUNCHER = "_repack_script_launcher.sh" @@ -329,6 +329,8 @@ def __init__( task=None, skip_model_validation=None, source_uri=None, + model_card=None, + model_life_cycle=None, **kwargs, ): """Constructor of a register model step. @@ -381,6 +383,9 @@ def __init__( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). + model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative and + quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -418,6 +423,8 @@ def __init__( self.container_def_list = container_def_list self.skip_model_validation = skip_model_validation self.source_uri = source_uri + self.model_card = model_card + self.model_life_cycle = model_life_cycle self._properties = Properties( step_name=name, step=self, shape_name="DescribeModelPackageOutput" @@ -493,6 +500,8 @@ def arguments(self) -> RequestType: task=self.task, skip_model_validation=self.skip_model_validation, source_uri=self.source_uri, + model_card=self.model_card, + model_life_cycle=self.model_life_cycle, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index 3678c3d97e..82f76304f5 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -1066,7 +1066,7 @@ def deploy_config_from_estimator( model_name=None, endpoint_name=None, tags=None, - **kwargs + **kwargs, ): """Export Airflow deploy config from a SageMaker estimator diff --git a/src/sagemaker/workflow/automl_step.py b/src/sagemaker/workflow/automl_step.py index 4900da9b98..498c726866 100644 --- a/src/sagemaker/workflow/automl_step.py +++ b/src/sagemaker/workflow/automl_step.py @@ -34,11 +34,11 @@ def __init__( self, name: str, step_args: _JobStepArguments, - display_name: str = None, - description: str = None, - cache_config: CacheConfig = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + cache_config: Optional[CacheConfig] = None, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, - retry_policies: List[RetryPolicy] = None, + retry_policies: Optional[List[RetryPolicy]] = None, ): """Construct a `AutoMLStep`, given a `AutoML` instance. 
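A minimal sketch (not part of this patch) of why shutil.copytree(..., dirs_exist_ok=True) can replace the removed distutils copy_tree: since Python 3.8 it merges into a pre-existing target (such as /opt/ml/model) instead of raising FileExistsError. The paths below are hypothetical:

import os
import shutil
import tempfile

src = tempfile.mkdtemp()  # stands in for the repack "src" dir
dst = tempfile.mkdtemp()  # already exists, like /opt/ml/model
open(os.path.join(src, "inference.py"), "w").close()

# Merges src into the existing dst; would raise without dirs_exist_ok=True.
shutil.copytree(src, dst, dirs_exist_ok=True)
assert os.path.exists(os.path.join(dst, "inference.py"))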
diff --git a/src/sagemaker/workflow/callback_step.py b/src/sagemaker/workflow/callback_step.py index 03903ef908..9a874b6bcc 100644 --- a/src/sagemaker/workflow/callback_step.py +++ b/src/sagemaker/workflow/callback_step.py @@ -84,9 +84,9 @@ def __init__( sqs_queue_url: str, inputs: dict, outputs: List[CallbackOutput], - display_name: str = None, - description: str = None, - cache_config: CacheConfig = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + cache_config: Optional[CacheConfig] = None, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a CallbackStep. diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py index 11fbb2c00b..52f62dd706 100644 --- a/src/sagemaker/workflow/clarify_check_step.py +++ b/src/sagemaker/workflow/clarify_check_step.py @@ -159,11 +159,11 @@ def __init__( skip_check: Union[bool, PipelineVariable] = False, fail_on_violation: Union[bool, PipelineVariable] = True, register_new_baseline: Union[bool, PipelineVariable] = False, - model_package_group_name: Union[str, PipelineVariable] = None, - supplied_baseline_constraints: Union[str, PipelineVariable] = None, - display_name: str = None, - description: str = None, - cache_config: CacheConfig = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + cache_config: Optional[CacheConfig] = None, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a ClarifyCheckStep. diff --git a/src/sagemaker/workflow/condition_step.py b/src/sagemaker/workflow/condition_step.py index cfb1606830..e9302f1134 100644 --- a/src/sagemaker/workflow/condition_step.py +++ b/src/sagemaker/workflow/condition_step.py @@ -20,7 +20,6 @@ from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.functions import JsonGet as NewJsonGet -from sagemaker.workflow.step_outputs import StepOutput from sagemaker.workflow.steps import ( Step, StepTypeEnum, @@ -41,11 +40,11 @@ def __init__( self, name: str, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, - display_name: str = None, - description: str = None, - conditions: List[Condition] = None, - if_steps: List[Union[Step, StepCollection, StepOutput]] = None, - else_steps: List[Union[Step, StepCollection, StepOutput]] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + conditions: Optional[List[Condition]] = None, + if_steps: Optional[List[Union[Step, StepCollection]]] = None, + else_steps: Optional[List[Union[Step, StepCollection]]] = None, ): """Construct a ConditionStep for pipelines to support conditional branching. diff --git a/src/sagemaker/workflow/emr_step.py b/src/sagemaker/workflow/emr_step.py index bea2c469f8..293c45bc6c 100644 --- a/src/sagemaker/workflow/emr_step.py +++ b/src/sagemaker/workflow/emr_step.py @@ -161,9 +161,9 @@ def __init__( cluster_id: str, step_config: EMRStepConfig, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, - cache_config: CacheConfig = None, - cluster_config: Dict[str, Any] = None, - execution_role_arn: str = None, + cache_config: Optional[CacheConfig] = None, + cluster_config: Optional[Dict[str, Any]] = None, + execution_role_arn: Optional[str] = None, ): """Constructs an `EMRStep`. 
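A minimal sketch (not part of this patch) of what the repeated `str = None` to `Optional[str] = None` changes in these step constructors buy: PEP 484 deprecates implicit Optional, so strict type checkers flag the old form. The function names are hypothetical:

from typing import Optional

def old_style(display_name: str = None) -> None:  # flagged under mypy --strict
    ...

def new_style(display_name: Optional[str] = None) -> None:  # accepted
    ...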
diff --git a/src/sagemaker/workflow/fail_step.py b/src/sagemaker/workflow/fail_step.py index fcb3411760..0a2510dd91 100644 --- a/src/sagemaker/workflow/fail_step.py +++ b/src/sagemaker/workflow/fail_step.py @@ -29,9 +29,9 @@ class FailStep(Step): def __init__( self, name: str, - error_message: Union[str, PipelineVariable] = None, - display_name: str = None, - description: str = None, + error_message: Optional[Union[str, PipelineVariable]] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a `FailStep`. diff --git a/src/sagemaker/workflow/lambda_step.py b/src/sagemaker/workflow/lambda_step.py index 51e046d5f2..9bcddf3045 100644 --- a/src/sagemaker/workflow/lambda_step.py +++ b/src/sagemaker/workflow/lambda_step.py @@ -84,11 +84,11 @@ def __init__( self, name: str, lambda_func: Lambda, - display_name: str = None, - description: str = None, - inputs: dict = None, - outputs: List[LambdaOutput] = None, - cache_config: CacheConfig = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + inputs: Optional[dict] = None, + outputs: Optional[List[LambdaOutput]] = None, + cache_config: Optional[CacheConfig] = None, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a LambdaStep. diff --git a/src/sagemaker/workflow/model_step.py b/src/sagemaker/workflow/model_step.py index 0ef77a5cd0..2b86f3726b 100644 --- a/src/sagemaker/workflow/model_step.py +++ b/src/sagemaker/workflow/model_step.py @@ -302,7 +302,7 @@ def _append_repack_model_step(self): self._repack_model_step_settings.pop("output_kms_key", None) or model.model_kms_key ), - **self._repack_model_step_settings + **self._repack_model_step_settings, ) self.steps.append(repack_model_step) diff --git a/src/sagemaker/workflow/monitor_batch_transform_step.py b/src/sagemaker/workflow/monitor_batch_transform_step.py index aa86103874..d80cda38df 100644 --- a/src/sagemaker/workflow/monitor_batch_transform_step.py +++ b/src/sagemaker/workflow/monitor_batch_transform_step.py @@ -48,8 +48,8 @@ def __init__( check_job_configuration: CheckJobConfig, monitor_before_transform: bool = False, fail_on_violation: Union[bool, PipelineVariable] = True, - supplied_baseline_statistics: Union[str, PipelineVariable] = None, - supplied_baseline_constraints: Union[str, PipelineVariable] = None, + supplied_baseline_statistics: Optional[Union[str, PipelineVariable]] = None, + supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None, display_name: Optional[str] = None, description: Optional[str] = None, ): diff --git a/src/sagemaker/workflow/notebook_job_step.py b/src/sagemaker/workflow/notebook_job_step.py index 8a1dd6bc53..8db95a2fae 100644 --- a/src/sagemaker/workflow/notebook_job_step.py +++ b/src/sagemaker/workflow/notebook_job_step.py @@ -13,49 +13,33 @@ """The notebook job step definitions for workflow.""" from __future__ import absolute_import +import os import re import shutil -import os +from typing import Dict, List, Optional, Union -from typing import ( - List, - Optional, - Union, - Dict, +from sagemaker import vpc_utils +from sagemaker.config.config_schema import ( + NOTEBOOK_JOB_ROLE_ARN, + NOTEBOOK_JOB_S3_KMS_KEY_ID, + NOTEBOOK_JOB_S3_ROOT_URI, + NOTEBOOK_JOB_VOLUME_KMS_KEY_ID, + NOTEBOOK_JOB_VPC_CONFIG_SECURITY_GROUP_IDS, + NOTEBOOK_JOB_VPC_CONFIG_SUBNETS, ) - +from sagemaker.s3 import S3Uploader +from sagemaker.s3_utils import s3_path_join +from sagemaker.session 
import get_execution_role +from sagemaker.utils import Tags, _tmpdir, format_tags, name_from_base, resolve_value_from_config +from sagemaker.workflow.entities import PipelineVariable, RequestType from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join from sagemaker.workflow.properties import Properties from sagemaker.workflow.retry import RetryPolicy -from sagemaker.workflow.steps import ( - Step, - ConfigurableRetryStep, - StepTypeEnum, -) from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.step_outputs import StepOutput - -from sagemaker.workflow.entities import ( - RequestType, - PipelineVariable, -) +from sagemaker.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum from sagemaker.workflow.utilities import _collect_parameters, load_step_compilation_context -from sagemaker.session import get_execution_role - -from sagemaker.s3_utils import s3_path_join -from sagemaker.s3 import S3Uploader -from sagemaker.utils import _tmpdir, name_from_base, resolve_value_from_config, format_tags, Tags -from sagemaker import vpc_utils - -from sagemaker.config.config_schema import ( - NOTEBOOK_JOB_ROLE_ARN, - NOTEBOOK_JOB_S3_ROOT_URI, - NOTEBOOK_JOB_S3_KMS_KEY_ID, - NOTEBOOK_JOB_VOLUME_KMS_KEY_ID, - NOTEBOOK_JOB_VPC_CONFIG_SUBNETS, - NOTEBOOK_JOB_VPC_CONFIG_SECURITY_GROUP_IDS, -) # disable E1101 as collect_parameters decorator sets the attributes @@ -259,25 +243,27 @@ def _validate_inputs(self): # input notebook is required if not self.input_notebook or not os.path.isfile(self.input_notebook): errors.append( - f"The required input notebook({self.input_notebook}) is not a valid " f"file." + f"The required input notebook ({self.input_notebook}) is not a valid file." ) # init script is optional if self.initialization_script and not os.path.isfile(self.initialization_script): - errors.append(f"The initialization script({self.input_notebook}) is not a valid file.") + errors.append( + f"The initialization script ({self.initialization_script}) is not a valid file." + ) if self.additional_dependencies: for path in self.additional_dependencies: if not os.path.exists(path): errors.append( - f"The path({path}) specified in additional dependencies does not exist." + f"The path ({path}) specified in additional dependencies does not exist." ) # image uri is required if not self.image_uri or self._region_from_session not in self.image_uri: errors.append( - f"The image uri(specified as {self.image_uri}) is required and " + f"The image uri (specified as {self.image_uri}) is required and " f"should be hosted in same region of the session" - f"({self._region_from_session})." + f" ({self._region_from_session})." ) if not self.kernel_name: @@ -374,7 +360,7 @@ def _prepare_env_variables(self): execution mechanism. 
""" - job_envs = self.environment_variables if self.environment_variables else {} + job_envs = dict(self.environment_variables or {}) system_envs = { "AWS_DEFAULT_REGION": self._region_from_session, "SM_JOB_DEF_VERSION": "1.0", diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 62167b96e7..f1a62fa637 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -125,6 +125,15 @@ def __init__( self.sagemaker_session.boto_session.client("scheduler"), ) + @property + def latest_pipeline_version_id(self): + """Retrieves the latest version id of this pipeline""" + summaries = self.list_pipeline_versions(max_results=1)["PipelineVersionSummaries"] + if not summaries: + return None + else: + return summaries[0].get("PipelineVersionId") + def create( self, role_arn: str = None, @@ -166,7 +175,8 @@ def create( kwargs, Tags=tags, ) - return self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs) + response = self.sagemaker_session.sagemaker_client.create_pipeline(**kwargs) + return response def _create_args( self, role_arn: str, description: str, parallelism_config: ParallelismConfiguration @@ -214,15 +224,21 @@ def _create_args( ) return kwargs - def describe(self) -> Dict[str, Any]: + def describe(self, pipeline_version_id: int = None) -> Dict[str, Any]: """Describes a Pipeline in the Workflow service. + Args: + pipeline_version_id (Optional[str]): version ID of the pipeline to describe. + Returns: Response dict from the service. See `boto3 client documentation `_ """ - return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name) + kwargs = dict(PipelineName=self.name) + if pipeline_version_id: + kwargs["PipelineVersionId"] = pipeline_version_id + return self.sagemaker_session.sagemaker_client.describe_pipeline(**kwargs) def update( self, @@ -257,7 +273,8 @@ def update( return self.sagemaker_session.sagemaker_client.update_pipeline(self, description) kwargs = self._create_args(role_arn, description, parallelism_config) - return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs) + response = self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs) + return response def upsert( self, @@ -332,6 +349,7 @@ def start( execution_description: str = None, parallelism_config: ParallelismConfiguration = None, selective_execution_config: SelectiveExecutionConfig = None, + pipeline_version_id: int = None, ): """Starts a Pipeline execution in the Workflow service. @@ -345,6 +363,8 @@ def start( over the parallelism configuration of the parent pipeline. selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration for selective step execution. + pipeline_version_id (Optional[str]): version ID of the pipeline to start the execution from. If not + specified, uses the latest version ID. Returns: A `_PipelineExecution` instance, if successful. @@ -366,6 +386,7 @@ def start( PipelineExecutionDisplayName=execution_display_name, ParallelismConfiguration=parallelism_config, SelectiveExecutionConfig=selective_execution_config, + PipelineVersionId=pipeline_version_id, ) if self.sagemaker_session.local_mode: update_args(kwargs, PipelineParameters=parameters) @@ -383,7 +404,11 @@ def start( ) def definition(self) -> str: - """Converts a request structure to string representation for workflow service calls.""" + """Converts a request structure to string representation for workflow service calls. + + Returns: + A JSON formatted string of pipeline definition. 
+ """ compiled_steps = StepsCompiler( pipeline_name=self.name, sagemaker_session=self.sagemaker_session, @@ -457,6 +482,32 @@ def list_executions( if key in response } + def list_pipeline_versions( + self, sort_order: str = None, max_results: int = None, next_token: str = None + ) -> str: + """Lists a pipeline's versions. + + Args: + sort_order (str): The sort order for results (Ascending/Descending). + max_results (int): The maximum number of pipeline executions to return in the response. + next_token (str): If the result of the previous `ListPipelineExecutions` request was + truncated, the response includes a `NextToken`. To retrieve the next set of pipeline + executions, use the token in the next request. + + Returns: + List of Pipeline Version Summaries. See + boto3 client list_pipeline_versions + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/list_pipeline_versions.html# + """ + kwargs = dict(PipelineName=self.name) + update_args( + kwargs, + SortOrder=sort_order, + NextToken=next_token, + MaxResults=max_results, + ) + return self.sagemaker_session.sagemaker_client.list_pipeline_versions(**kwargs) + def _get_latest_execution_arn(self): """Retrieves the latest execution of this pipeline""" response = self.list_executions( @@ -851,7 +902,7 @@ def describe(self): sagemaker.html#SageMaker.Client.describe_pipeline_execution>`_. """ return self.sagemaker_session.sagemaker_client.describe_pipeline_execution( - PipelineExecutionArn=self.arn, + PipelineExecutionArn=self.arn ) def list_steps(self): diff --git a/src/sagemaker/workflow/quality_check_step.py b/src/sagemaker/workflow/quality_check_step.py index c99ce587ac..8ea98e8c65 100644 --- a/src/sagemaker/workflow/quality_check_step.py +++ b/src/sagemaker/workflow/quality_check_step.py @@ -125,12 +125,12 @@ def __init__( skip_check: Union[bool, PipelineVariable] = False, fail_on_violation: Union[bool, PipelineVariable] = True, register_new_baseline: Union[bool, PipelineVariable] = False, - model_package_group_name: Union[str, PipelineVariable] = None, - supplied_baseline_statistics: Union[str, PipelineVariable] = None, - supplied_baseline_constraints: Union[str, PipelineVariable] = None, - display_name: str = None, - description: str = None, - cache_config: CacheConfig = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + supplied_baseline_statistics: Optional[Union[str, PipelineVariable]] = None, + supplied_baseline_constraints: Optional[Union[str, PipelineVariable]] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + cache_config: Optional[CacheConfig] = None, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a QualityCheckStep. 
diff --git a/src/sagemaker/workflow/retry.py b/src/sagemaker/workflow/retry.py index 0df915e8e7..bd8f9cf8c6 100644 --- a/src/sagemaker/workflow/retry.py +++ b/src/sagemaker/workflow/retry.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from enum import Enum -from typing import List +from typing import List, Optional import attr from sagemaker.workflow.entities import Entity, DefaultEnumMeta, RequestType @@ -133,8 +133,8 @@ def __init__( exception_types: List[StepExceptionTypeEnum], backoff_rate: float = 2.0, interval_seconds: int = 1, - max_attempts: int = None, - expire_after_mins: int = None, + max_attempts: Optional[int] = None, + expire_after_mins: Optional[int] = None, ): super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins) for exception_type in exception_types: @@ -177,12 +177,12 @@ class SageMakerJobStepRetryPolicy(RetryPolicy): def __init__( self, - exception_types: List[SageMakerJobExceptionTypeEnum] = None, - failure_reason_types: List[SageMakerJobExceptionTypeEnum] = None, + exception_types: Optional[List[SageMakerJobExceptionTypeEnum]] = None, + failure_reason_types: Optional[List[SageMakerJobExceptionTypeEnum]] = None, backoff_rate: float = 2.0, interval_seconds: int = 1, - max_attempts: int = None, - expire_after_mins: int = None, + max_attempts: Optional[int] = None, + expire_after_mins: Optional[int] = None, ): super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins) diff --git a/src/sagemaker/workflow/selective_execution_config.py b/src/sagemaker/workflow/selective_execution_config.py index dea08d2b8c..5c400937a6 100644 --- a/src/sagemaker/workflow/selective_execution_config.py +++ b/src/sagemaker/workflow/selective_execution_config.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. """Pipeline Parallelism Configuration""" from __future__ import absolute_import -from typing import List +from typing import List, Optional from sagemaker.workflow.entities import RequestType @@ -25,8 +25,8 @@ class SelectiveExecutionConfig: def __init__( self, selected_steps: List[str], - source_pipeline_execution_arn: str = None, reference_latest_execution: bool = True, + source_pipeline_execution_arn: Optional[str] = None, ): """Create a `SelectiveExecutionConfig`. 
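A minimal sketch (not part of this patch) of constructing the reordered SelectiveExecutionConfig; the step name is hypothetical:

from sagemaker.workflow.selective_execution_config import SelectiveExecutionConfig

# With reference_latest_execution=True the source execution ARN can be
# omitted and the latest pipeline execution is used as the reference.
config = SelectiveExecutionConfig(
    selected_steps=["TrainStep"],
    reference_latest_execution=True,
)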
diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 0eedf4aa96..a1d939254c 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -72,11 +72,11 @@ def __init__( response_types, inference_instances=None, transform_instances=None, - estimator: EstimatorBase = None, + estimator: Optional[EstimatorBase] = None, model_data=None, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, - repack_model_step_retry_policies: List[RetryPolicy] = None, - register_model_step_retry_policies: List[RetryPolicy] = None, + repack_model_step_retry_policies: Optional[List[RetryPolicy]] = None, + register_model_step_retry_policies: Optional[List[RetryPolicy]] = None, model_package_group_name=None, model_metrics=None, approval_status=None, @@ -85,7 +85,7 @@ def __init__( display_name=None, description=None, tags=None, - model: Union[Model, PipelineModel] = None, + model: Optional[Union[Model, PipelineModel]] = None, drift_check_baselines=None, customer_metadata_properties=None, domain=None, @@ -97,6 +97,8 @@ def __init__( data_input_configuration=None, skip_model_validation=None, source_uri=None, + model_card=None, + model_life_cycle=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -155,7 +157,9 @@ def __init__( skip_model_validation (str): Indicates if you want to skip model validation. Values can be "All" or "None" (default: None). source_uri (str): The URI of the source for the model package (default: None). - + model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative and + quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): ModelLifeCycle object (default: None). **kwargs: additional arguments to `create_model`. 
""" super().__init__(name=name, depends_on=depends_on) @@ -294,6 +298,8 @@ def __init__( task=task, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + model_life_cycle=model_life_cycle, **kwargs, ) if not repack_model: @@ -325,8 +331,8 @@ def __init__( instance_count, instance_type, transform_inputs, - description: str = None, - display_name: str = None, + description: Optional[str] = None, + display_name: Optional[str] = None, # model arguments image_uri=None, predictor_cls=None, @@ -343,9 +349,9 @@ def __init__( volume_kms_key=None, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, # step retry policies - repack_model_step_retry_policies: List[RetryPolicy] = None, - model_step_retry_policies: List[RetryPolicy] = None, - transform_step_retry_policies: List[RetryPolicy] = None, + repack_model_step_retry_policies: Optional[List[RetryPolicy]] = None, + model_step_retry_policies: Optional[List[RetryPolicy]] = None, + transform_step_retry_policies: Optional[List[RetryPolicy]] = None, **kwargs, ): """Construct steps required for a Transformer step collection: diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 44e0b34e54..dbc37371db 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -18,7 +18,6 @@ from enum import Enum from typing import Dict, List, Set, Union, Optional, Any, TYPE_CHECKING -from urllib.parse import urlparse import attr @@ -362,10 +361,10 @@ def __init__( self, name: str, step_type: StepTypeEnum, - display_name: str = None, - description: str = None, - depends_on: Optional[List[Union[str, Step, "StepCollection", StepOutput]]] = None, - retry_policies: List[RetryPolicy] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, + retry_policies: Optional[List[RetryPolicy]] = None, ): super().__init__( name=name, @@ -404,14 +403,14 @@ class TrainingStep(ConfigurableRetryStep): def __init__( self, name: str, - step_args: _JobStepArguments = None, - estimator: EstimatorBase = None, - display_name: str = None, - description: str = None, - inputs: Union[TrainingInput, dict, str, FileSystemInput] = None, - cache_config: CacheConfig = None, + step_args: Optional[_JobStepArguments] = None, + estimator: Optional[EstimatorBase] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + inputs: Optional[Union[TrainingInput, dict, str, FileSystemInput]] = None, + cache_config: Optional[CacheConfig] = None, depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, - retry_policies: List[RetryPolicy] = None, + retry_policies: Optional[List[RetryPolicy]] = None, ): """Construct a `TrainingStep`, given an `EstimatorBase` instance. @@ -465,6 +464,7 @@ def __init__( self.step_args = step_args self.estimator = estimator self.inputs = inputs + self.job_name = None self._properties = Properties( step_name=name, step=self, shape_name="DescribeTrainingJobResponse" @@ -493,19 +493,6 @@ def __init__( DeprecationWarning, ) - self.job_name = None - if estimator and (estimator.source_dir or estimator.entry_point): - # By default, `Estimator` will upload the local code to an S3 path - # containing a timestamp. This causes cache misses whenever a - # pipeline is updated, even if the underlying script hasn't changed. 
- # To avoid this, hash the contents of the training script and include it - # in the `job_name` passed to the `Estimator`, which will be used - # instead of the timestamped path. - if not is_pipeline_variable(estimator.source_dir) and not is_pipeline_variable( - estimator.entry_point - ): - self.job_name = self._generate_code_upload_path() - @property def arguments(self) -> RequestType: """The arguments dictionary that is used to call `create_training_job`. @@ -554,26 +541,6 @@ def to_request(self) -> RequestType: return request_dict - def _generate_code_upload_path(self) -> str or None: - """Generate an upload path for local training scripts based on their content.""" - from sagemaker.workflow.utilities import hash_files_or_dirs - - if self.estimator.source_dir: - source_dir_url = urlparse(self.estimator.source_dir) - if source_dir_url.scheme == "" or source_dir_url.scheme == "file": - code_hash = hash_files_or_dirs( - [self.estimator.source_dir] + self.estimator.dependencies - ) - return f"{self.name}-{code_hash}"[:1024] - elif self.estimator.entry_point: - entry_point_url = urlparse(self.estimator.entry_point) - if entry_point_url.scheme == "" or entry_point_url.scheme == "file": - code_hash = hash_files_or_dirs( - [self.estimator.entry_point] + self.estimator.dependencies - ) - return f"{self.name}-{code_hash}"[:1024] - return None - class CreateModelStep(ConfigurableRetryStep): """`CreateModelStep` for SageMaker Pipelines Workflows.""" @@ -645,6 +612,7 @@ def arguments(self) -> RequestType: request_dict = self.step_args else: if isinstance(self.model, PipelineModel): + self.model._init_sagemaker_session_if_does_not_exist() request_dict = self.model.sagemaker_session._create_model_request( name="", role=self.model.role, @@ -653,6 +621,7 @@ def arguments(self) -> RequestType: enable_network_isolation=self.model.enable_network_isolation, ) else: + self.model._init_sagemaker_session_if_does_not_exist() request_dict = self.model.sagemaker_session._create_model_request( name="", role=self.model.role, @@ -681,14 +650,14 @@ class TransformStep(ConfigurableRetryStep): def __init__( self, name: str, - step_args: _JobStepArguments = None, - transformer: Transformer = None, - inputs: TransformInput = None, - display_name: str = None, - description: str = None, - cache_config: CacheConfig = None, + step_args: Optional[_JobStepArguments] = None, + transformer: Optional[Transformer] = None, + inputs: Optional[TransformInput] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + cache_config: Optional[CacheConfig] = None, depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, - retry_policies: List[RetryPolicy] = None, + retry_policies: Optional[List[RetryPolicy]] = None, ): """Constructs a `TransformStep`, given a `Transformer` instance. 
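A minimal sketch (not part of this patch) of the deterministic naming scheme that the removed _generate_code_upload_path implemented, using the hashing helper that remains in sagemaker.workflow.utilities; the function name code_upload_path is hypothetical:

from sagemaker.workflow.utilities import hash_files_or_dirs

def code_upload_path(step_name: str, source_dir: str, dependencies: list) -> str:
    # Hash the script contents so the S3 key only changes when the code does,
    # instead of on every (timestamped) pipeline update.
    code_hash = hash_files_or_dirs([source_dir] + dependencies)
    return f"{step_name}-{code_hash}"[:1024]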
@@ -808,19 +777,19 @@ class ProcessingStep(ConfigurableRetryStep): def __init__( self, name: str, - step_args: _JobStepArguments = None, - processor: Processor = None, - display_name: str = None, - description: str = None, - inputs: List[ProcessingInput] = None, - outputs: List[ProcessingOutput] = None, - job_arguments: List[str] = None, - code: str = None, - property_files: List[PropertyFile] = None, - cache_config: CacheConfig = None, + step_args: Optional[_JobStepArguments] = None, + processor: Optional[Processor] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + inputs: Optional[List[ProcessingInput]] = None, + outputs: Optional[List[ProcessingOutput]] = None, + job_arguments: Optional[List[str]] = None, + code: Optional[str] = None, + property_files: Optional[List[PropertyFile]] = None, + cache_config: Optional[CacheConfig] = None, depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, - retry_policies: List[RetryPolicy] = None, - kms_key=None, + retry_policies: Optional[List[RetryPolicy]] = None, + kms_key: Optional[str] = None, ): """Construct a `ProcessingStep`, given a `Processor` instance. @@ -893,16 +862,6 @@ def __init__( "code argument has to be a valid S3 URI or local file path " + "rather than a pipeline variable" ) - code_url = urlparse(code) - if code_url.scheme == "" or code_url.scheme == "file": - # By default, `Processor` will upload the local code to an S3 path - # containing a timestamp. This causes cache misses whenever a - # pipeline is updated, even if the underlying script hasn't changed. - # To avoid this, hash the contents of the script and include it - # in the `job_name` passed to the `Processor`, which will be used - # instead of the timestamped path. - self.job_name = self._generate_code_upload_path() - warnings.warn( ( 'We are deprecating the instantiation of ProcessingStep using "processor".' @@ -980,15 +939,15 @@ class TuningStep(ConfigurableRetryStep): def __init__( self, name: str, - step_args: _JobStepArguments = None, - tuner: HyperparameterTuner = None, - display_name: str = None, - description: str = None, + step_args: Optional[_JobStepArguments] = None, + tuner: Optional[HyperparameterTuner] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, inputs=None, - job_arguments: List[str] = None, - cache_config: CacheConfig = None, + job_arguments: Optional[List[str]] = None, + cache_config: Optional[CacheConfig] = None, depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None, - retry_policies: List[RetryPolicy] = None, + retry_policies: Optional[List[RetryPolicy]] = None, ): """Construct a `TuningStep`, given a `HyperparameterTuner` instance. diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 4ef5ad5dd2..961972da4d 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -21,7 +21,15 @@ import hashlib from urllib.parse import unquote, urlparse from contextlib import contextmanager -from _hashlib import HASH as Hash + +try: + # _hashlib is an internal python module, and is not present in + # statically linked interpreters. + from _hashlib import HASH as Hash +except ImportError: + import typing + + Hash = typing.Any from sagemaker.utils import base_from_name from sagemaker.workflow.parameters import Parameter @@ -268,29 +276,29 @@ def get_config_hash(step: Entity): def hash_object(obj) -> str: - """Get the MD5 hash of an object. + """Get the SHA256 hash of an object. 
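The guarded `_hashlib` import above keeps type annotations usable on interpreters that do not ship CPython's private `_hashlib` module. A minimal standalone sketch of the same pattern, assuming only the standard library:

import hashlib

try:
    # CPython-only: the concrete type of hashlib's hash objects.
    from _hashlib import HASH as Hash
except ImportError:
    # Statically linked or alternative interpreters may omit _hashlib;
    # fall back to Any so the annotations below still evaluate.
    import typing

    Hash = typing.Any


def mix_in(label: str, digest: Hash) -> Hash:
    # Fold a label into a running digest and hand it back.
    digest.update(label.encode())
    return digest


print(mix_in("example", hashlib.sha256()).hexdigest())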
Args: obj (dict): The object Returns: - str: The MD5 hash of the object + str: The SHA256 hash of the object """ - return hashlib.md5(str(obj).encode()).hexdigest() + return hashlib.sha256(str(obj).encode()).hexdigest() def hash_file(path: str) -> str: - """Get the MD5 hash of a file. + """Get the SHA256 hash of a file. Args: path (str): The local path for the file. Returns: - str: The MD5 hash of the file. + str: The SHA256 hash of the file. """ - return _hash_file(path, hashlib.md5()).hexdigest() + return _hash_file(path, hashlib.sha256()).hexdigest() def hash_files_or_dirs(paths: List[str]) -> str: - """Get the MD5 hash of the contents of a list of files or directories. + """Get the SHA256 hash of the contents of a list of files or directories. Hash is changed if: * input list is changed @@ -301,58 +309,58 @@ def hash_files_or_dirs(paths: List[str]) -> str: Args: paths: List of file or directory paths Returns: - str: The MD5 hash of the list of files or directories. + str: The SHA256 hash of the list of files or directories. """ - md5 = hashlib.md5() + sha256 = hashlib.sha256() for path in sorted(paths): - md5 = _hash_file_or_dir(path, md5) - return md5.hexdigest() + sha256 = _hash_file_or_dir(path, sha256) + return sha256.hexdigest() -def _hash_file_or_dir(path: str, md5: Hash) -> Hash: +def _hash_file_or_dir(path: str, sha256: Hash) -> Hash: """Updates the inputted Hash with the contents of the current path. Args: path: path of file or directory Returns: - str: The MD5 hash of the file or directory + str: The SHA256 hash of the file or directory """ if isinstance(path, str) and path.lower().startswith("file://"): path = unquote(urlparse(path).path) - md5.update(path.encode()) + sha256.update(path.encode()) if Path(path).is_dir(): - md5 = _hash_dir(path, md5) + sha256 = _hash_dir(path, sha256) elif Path(path).is_file(): - md5 = _hash_file(path, md5) - return md5 + sha256 = _hash_file(path, sha256) + return sha256 -def _hash_dir(directory: Union[str, Path], md5: Hash) -> Hash: +def _hash_dir(directory: Union[str, Path], sha256: Hash) -> Hash: """Updates the inputted Hash with the contents of the current path. Args: directory: path of the directory Returns: - str: The MD5 hash of the directory + str: The SHA256 hash of the directory """ if not Path(directory).is_dir(): raise ValueError(str(directory) + " is not a valid directory") for path in sorted(Path(directory).iterdir()): - md5.update(path.name.encode()) + sha256.update(path.name.encode()) if path.is_file(): - md5 = _hash_file(path, md5) + sha256 = _hash_file(path, sha256) elif path.is_dir(): - md5 = _hash_dir(path, md5) - return md5 + sha256 = _hash_dir(path, sha256) + return sha256 -def _hash_file(file: Union[str, Path], md5: Hash) -> Hash: +def _hash_file(file: Union[str, Path], sha256: Hash) -> Hash: """Updates the inputted Hash with the contents of the current path. 
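The helpers above walk directories in sorted order and fold both entry names and file bytes into one digest, so renames, edits, additions, and removals all change the result. The sketch below is a simplified standalone analogue of that scheme (not the SDK's exact traversal), assuming plain local paths:

import hashlib
import tempfile
from pathlib import Path

BUF_SIZE = 65536  # stream files in 64 KiB chunks to keep memory flat


def sha256_of_tree(root: str) -> str:
    digest = hashlib.sha256()
    for path in sorted(Path(root).rglob("*")):
        digest.update(path.name.encode())  # renames change the hash
        if path.is_file():
            with open(path, "rb") as f:
                while chunk := f.read(BUF_SIZE):
                    digest.update(chunk)  # content edits change the hash
    return digest.hexdigest()


root = tempfile.mkdtemp()
(Path(root) / "train.py").write_text("print('hi')\n")
print(sha256_of_tree(root))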
Args: file: path of the file Returns: - str: The MD5 hash of the file + str: The SHA256 hash of the file """ if isinstance(file, str) and file.lower().startswith("file://"): file = unquote(urlparse(file).path) @@ -363,8 +371,8 @@ def _hash_file(file: Union[str, Path], md5: Hash) -> Hash: data = f.read(BUF_SIZE) if not data: break - md5.update(data) - return md5 + sha256.update(data) + return sha256 def validate_step_args_input( diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index dfd7145e93..9385acf745 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -51,7 +51,7 @@ def __init__( py_version: str = "py3", image_uri: Optional[Union[str, PipelineVariable]] = None, image_uri_region: Optional[str] = None, - **kwargs + **kwargs, ): """An estimator that executes an XGBoost-based SageMaker Training Job. @@ -78,8 +78,8 @@ def __init__( source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI) to a directory with any other training source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when training on Amazon SageMaker. + point to a file with name ``sourcedir.tar.gz``. Structure within this directory + is preserved when training on Amazon SageMaker. hyperparameters (dict[str, str] or dict[str, PipelineVariable]): Hyperparameters that will be used for training (default: None). The hyperparameters are made accessible as a dict[str, str] to the training code @@ -137,7 +137,7 @@ def create_model( entry_point=None, source_dir=None, dependencies=None, - **kwargs + **kwargs, ): """Create a SageMaker ``XGBoostModel`` object that can be deployed to an ``Endpoint``. @@ -188,7 +188,7 @@ def create_model( sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), dependencies=(dependencies or self.dependencies), - **kwargs + **kwargs, ) @classmethod diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 8101f32721..f4797c79e7 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -23,6 +23,10 @@ from sagemaker.fw_utils import model_code_key_prefix from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME +from sagemaker.model_card import ( + ModelCard, + ModelPackageModelCard, +) from sagemaker.predictor import Predictor from sagemaker.serializers import LibSVMSerializer from sagemaker.utils import to_string @@ -30,6 +34,7 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.xgboost.defaults import XGBOOST_NAME from sagemaker.xgboost.utils import validate_py_version, validate_framework_version +from sagemaker.model_life_cycle import ModelLifeCycle logger = logging.getLogger("sagemaker") @@ -86,9 +91,9 @@ def __init__( framework_version: str = None, image_uri: Optional[Union[str, PipelineVariable]] = None, py_version: str = "py3", - predictor_cls: callable = XGBoostPredictor, + predictor_cls: Optional[Callable] = XGBoostPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, - **kwargs + **kwargs, ): """Initialize an XGBoostModel.
@@ -108,8 +113,8 @@ def __init__( (default: 'py3'). framework_version (str): XGBoost version you want to use for executing your model training code. - predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create - a predictor with an endpoint name and SageMaker ``Session``. + predictor_cls (Callable[[str, sagemaker.session.Session], Any]): A function to call + to create a predictor with an endpoint name and SageMaker ``Session``. If specified, ``deploy()`` returns the result of invoking this function on the created endpoint name. model_server_workers (int or PipelineVariable): Optional. The number of worker processes @@ -160,6 +165,8 @@ def register( data_input_configuration: Optional[Union[str, PipelineVariable]] = None, skip_model_validation: Optional[Union[str, PipelineVariable]] = None, source_uri: Optional[Union[str, PipelineVariable]] = None, + model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None, + model_life_cycle: Optional[ModelLifeCycle] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -211,6 +218,9 @@ def register( validation. Values can be "All" or "None" (default: None). source_uri (str or PipelineVariable): The URI of the source for the model package (default: None). + model_card (ModelCard or ModelPackageModelCard): A document that contains qualitative and + quantitative information about a model (default: None). + model_life_cycle (ModelLifeCycle): The ModelLifeCycle configuration of the model package (default: None). Returns: str: A string of SageMaker Model Package ARN. @@ -251,6 +261,8 @@ def register( data_input_configuration=data_input_configuration, skip_model_validation=skip_model_validation, source_uri=source_uri, + model_card=model_card, + model_life_cycle=model_life_cycle, ) def prepare_container_def( @@ -259,6 +271,7 @@ self, accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition with framework configuration.
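One plausible call site for the new register parameters, sketched with placeholder resource names; the ModelPackageModelCard and ModelLifeCycle constructor arguments shown are assumptions for illustration, so check the SDK reference before relying on them:

from sagemaker.model_card import ModelPackageModelCard
from sagemaker.model_life_cycle import ModelLifeCycle
from sagemaker.xgboost.model import XGBoostModel

model = XGBoostModel(
    model_data="s3://my-bucket/model.tar.gz",             # placeholder
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
    entry_point="inference.py",
    framework_version="1.7-1",
)

# Assumed constructor shapes, for illustration only.
card = ModelPackageModelCard(
    model_card_content={"model_overview": {"model_description": "demo"}},
    model_card_status="Draft",
)
life_cycle = ModelLifeCycle(stage="Development", stage_status="In-Progress")

package_arn = model.register(
    content_types=["text/csv"],
    response_types=["text/csv"],
    model_package_group_name="my-group",  # placeholder
    model_card=card,                      # new in this change
    model_life_cycle=life_cycle,          # new in this change
)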
@@ -306,6 +319,7 @@ def prepare_container_def( model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): diff --git a/tests/conftest.py b/tests/conftest.py index 0309781e7b..34f5c5306d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,7 @@ from botocore.config import Config from packaging.version import Version +from packaging.specifiers import SpecifierSet from sagemaker import Session, image_uris, utils, get_execution_role from sagemaker.local import LocalSession @@ -253,7 +254,11 @@ def mxnet_eia_latest_py_version(): @pytest.fixture(scope="module", params=["py2", "py3"]) def pytorch_training_py_version(pytorch_training_version, request): - if Version(pytorch_training_version) >= Version("2.0"): + if Version(pytorch_training_version) >= Version("2.6"): + return "py312" + if Version(pytorch_training_version) >= Version("2.3"): + return "py311" + elif Version(pytorch_training_version) >= Version("2.0"): return "py310" elif Version(pytorch_training_version) >= Version("1.13"): return "py39" @@ -267,7 +272,11 @@ def pytorch_training_py_version(pytorch_training_version, request): @pytest.fixture(scope="module", params=["py2", "py3"]) def pytorch_inference_py_version(pytorch_inference_version, request): - if Version(pytorch_inference_version) >= Version("2.0"): + if Version(pytorch_inference_version) >= Version("2.6"): + return "py312" + elif Version(pytorch_inference_version) >= Version("2.3"): + return "py311" + elif Version(pytorch_inference_version) >= Version("2.0"): return "py310" elif Version(pytorch_inference_version) >= Version("1.13"): return "py39" @@ -288,6 +297,8 @@ def huggingface_pytorch_training_version(huggingface_training_version): @pytest.fixture(scope="module") def huggingface_pytorch_training_py_version(huggingface_pytorch_training_version): + if Version(huggingface_pytorch_training_version) >= Version("2.3"): + return "py311" if Version(huggingface_pytorch_training_version) >= Version("2.0"): return "py310" elif Version(huggingface_pytorch_training_version) >= Version("1.13"): @@ -350,6 +361,8 @@ def huggingface_training_compiler_pytorch_py_version( def huggingface_pytorch_latest_training_py_version( huggingface_training_pytorch_latest_version, ): + if Version(huggingface_training_pytorch_latest_version) >= Version("2.3"): + return "py311" if Version(huggingface_training_pytorch_latest_version) >= Version("2.0"): return "py310" elif Version(huggingface_training_pytorch_latest_version) >= Version("1.13"): @@ -541,7 +554,9 @@ def _tf_py_version(tf_version, request): return "py38" if Version("2.8") <= version < Version("2.12"): return "py39" - return "py310" + if Version("2.12") <= version < Version("2.19"): + return "py310" + return "py312" @pytest.fixture(scope="module") @@ -551,11 +566,18 @@ def tf_full_version(tensorflow_training_latest_version, tensorflow_inference_lat Fixture exists as such, since TF training and TFS have different latest versions. Otherwise, this would simply be a single latest version. 
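The conftest fixtures above lean on packaging's Version for real semantic ordering rather than string comparison (lexically, "2.10" sorts before "2.9"). A condensed standalone sketch of the same gating, with cut-offs mirroring the PyTorch fixtures:

from packaging.version import Version


def py_version_for(framework_version: str) -> str:
    # Highest matching floor wins; Version("2.10") > Version("2.6"),
    # which a plain string comparison would get wrong.
    v = Version(framework_version)
    if v >= Version("2.6"):
        return "py312"
    if v >= Version("2.3"):
        return "py311"
    if v >= Version("2.0"):
        return "py310"
    return "py39"


assert py_version_for("2.6.0") == "py312"
assert py_version_for("2.10") == "py312"  # lexical compare would misplace this
assert py_version_for("1.13.1") == "py39"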
""" - return str( - min( - Version(tensorflow_training_latest_version), - Version(tensorflow_inference_latest_version), - ) + tensorflow_training_latest_version = Version(tensorflow_training_latest_version) + tensorflow_inference_latest_version = Version(tensorflow_inference_latest_version) + + return_version = min( + tensorflow_training_latest_version, + tensorflow_inference_latest_version, + ) + + return ( + f"{return_version.major}.{return_version.minor}" + if return_version in SpecifierSet(">=2.16") + else str(return_version) ) @@ -577,7 +599,9 @@ def tf_full_py_version(tf_full_version): return "py38" if version < Version("2.12"): return "py39" - return "py310" + if version < Version("2.19"): + return "py310" + return "py312" @pytest.fixture(scope="module") diff --git a/tests/data/_repack_model.py b/tests/data/_repack_model.py index 3cfa6760b3..b370db5dbf 100644 --- a/tests/data/_repack_model.py +++ b/tests/data/_repack_model.py @@ -26,13 +26,6 @@ # is unpacked for inference, the custom entry point will be used. # Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html -# distutils.dir_util.copy_tree works way better than the half-baked -# shutil.copytree which bombs on previously existing target dirs... -# alas ... https://bugs.python.org/issue10948 -# we'll go ahead and use the copy_tree function anyways because this -# repacking is some short-lived hackery, right?? -from distutils.dir_util import copy_tree - def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover """Repack custom dependencies and code into an existing model TAR archive @@ -92,7 +85,7 @@ def repack(inference_script, model_archive, dependencies=None, source_dir=None): # copy the "src" dir, which includes the previous training job's model and the # custom inference script, to the output of this training job - copy_tree(src_dir, "/opt/ml/model") + shutil.copytree(src_dir, "/opt/ml/model", dirs_exist_ok=True) if __name__ == "__main__": # pragma: no cover diff --git a/tests/data/marketplace/iris/scoring_logic.py b/tests/data/marketplace/iris/scoring_logic.py index f9e2f1bb35..c1a4cc1642 100644 --- a/tests/data/marketplace/iris/scoring_logic.py +++ b/tests/data/marketplace/iris/scoring_logic.py @@ -3,7 +3,7 @@ import logging import re from flask import Flask -from flask import request +from flask import request, escape from joblib import dump, load import numpy as np import os @@ -73,37 +73,34 @@ def endpoint_ping(): # Create a path for inference @app.route("/invocations", methods=["POST"]) def endpoint_invocations(): - try: - logger.info(f"Processing request: {request.headers}") - logger.debug(f"Payload: {request.headers}") - - if request.content_type not in SUPPORTED_REQUEST_MIMETYPES: - logger.error(f"Unsupported Content-Type specified: {request.content_type}") - return f"Invalid Content-Type. 
Supported Content-Types: {', '.join(SUPPORTED_REQUEST_MIMETYPES)}" - elif request.content_type == "text/csv": - # Step 1: Decode payload into input format expected by model - data = request.get_data().decode("utf8") - # Step 2: Perform inference with the loaded model - predictions = model.predict_from_csv(data) - elif request.content_type == "application/json": - data = request.get_data().decode("utf8") - predictions = model.predict_from_json(data) - elif request.content_type == "application/jsonlines": - data = request.get_data().decode("utf8") - predictions = model.predict_from_jsonlines(data) - - # Step 3: Process predictions into the specified response type (if specified) - response_mimetype = request.accept_mimetypes.best_match( - SUPPORTED_RESPONSE_MIMETYPES, default="application/json" - ) - - if response_mimetype == "text/csv": - response = "\n".join(predictions) - elif response_mimetype == "application/jsonlines": - response = "\n".join([json.dumps({"class": pred}) for pred in predictions]) - elif response_mimetype == "application/json": - response = json.dumps({"predictions": [{"class": pred} for pred in predictions]}) - - return response - except Exception as e: - return f"Error during model invocation: {str(e)} for input: {request.get_data()}" + logger.info(f"Processing request: {request.headers}") + logger.debug(f"Payload: {request.get_data()}") + + if request.content_type not in SUPPORTED_REQUEST_MIMETYPES: + logger.error(f"Unsupported Content-Type specified: {request.content_type}") + return f"Invalid Content-Type. Supported Content-Types: {', '.join(SUPPORTED_REQUEST_MIMETYPES)}" + elif request.content_type == "text/csv": + # Step 1: Decode payload into input format expected by model + data = request.get_data().decode("utf8") + # Step 2: Perform inference with the loaded model + predictions = model.predict_from_csv(data) + elif request.content_type == "application/json": + data = request.get_data().decode("utf8") + predictions = model.predict_from_json(data) + elif request.content_type == "application/jsonlines": + data = request.get_data().decode("utf8") + predictions = model.predict_from_jsonlines(data) + + # Step 3: Process predictions into the specified response type (if specified) + response_mimetype = request.accept_mimetypes.best_match( + SUPPORTED_RESPONSE_MIMETYPES, default="application/json" + ) + + if response_mimetype == "text/csv": + response = "\n".join(predictions) + elif response_mimetype == "application/jsonlines": + response = "\n".join([json.dumps({"class": pred}) for pred in predictions]) + elif response_mimetype == "application/json": + response = json.dumps({"predictions": [{"class": pred} for pred in predictions]}) + + return response diff --git a/tests/data/modules/custom_drivers/driver.py b/tests/data/modules/custom_drivers/driver.py new file mode 100644 index 0000000000..3395b80da9 --- /dev/null +++ b/tests/data/modules/custom_drivers/driver.py @@ -0,0 +1,34 @@ +import json +import os +import subprocess +import sys + + +def main(): + driver_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + process_count_per_node = driver_config["process_count_per_node"] + assert process_count_per_node is not None + + hps = json.loads(os.environ["SM_HPS"]) + assert hps is not None + assert isinstance(hps, dict) + + source_dir = os.environ["SM_SOURCE_DIR"] + assert source_dir == "/opt/ml/input/data/code" + sm_drivers_dir = os.environ["SM_DISTRIBUTED_DRIVER_DIR"] + assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/distributed_drivers" + + entry_script =
os.environ["SM_ENTRY_SCRIPT"] + assert entry_script is not None + + python = sys.executable + + command = [python, entry_script] + print(f"Running command: {command}") + subprocess.run(command, check=True) + + +if __name__ == "__main__": + print("Running custom driver script") + main() + print("Finished running custom driver script") diff --git a/tests/data/modules/local_script/data/test/x_test.npy b/tests/data/modules/local_script/data/test/x_test.npy new file mode 100644 index 0000000000..a9977e39c0 Binary files /dev/null and b/tests/data/modules/local_script/data/test/x_test.npy differ diff --git a/tests/data/modules/local_script/data/test/y_test.npy b/tests/data/modules/local_script/data/test/y_test.npy new file mode 100644 index 0000000000..a7191945ee Binary files /dev/null and b/tests/data/modules/local_script/data/test/y_test.npy differ diff --git a/tests/data/modules/local_script/data/train/x_train.npy b/tests/data/modules/local_script/data/train/x_train.npy new file mode 100644 index 0000000000..d267502e65 Binary files /dev/null and b/tests/data/modules/local_script/data/train/x_train.npy differ diff --git a/tests/data/modules/local_script/data/train/y_train.npy b/tests/data/modules/local_script/data/train/y_train.npy new file mode 100644 index 0000000000..b8c17c4972 Binary files /dev/null and b/tests/data/modules/local_script/data/train/y_train.npy differ diff --git a/tests/data/modules/local_script/local_training_script.py b/tests/data/modules/local_script/local_training_script.py new file mode 100644 index 0000000000..6bb73343c0 --- /dev/null +++ b/tests/data/modules/local_script/local_training_script.py @@ -0,0 +1,147 @@ +# flake8: noqa +import argparse +import numpy as np +import os +import sys +import logging +import json +import shutil +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset +from pytorch_model_def import get_model + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) +current_dir = os.path.dirname(os.path.abspath(__file__)) +data_dir = "/opt/ml/input/data" + + +def get_train_data(train_dir): + """ + Get the training data and convert to tensors + """ + + x_train = np.load(os.path.join(train_dir, "x_train.npy")) + y_train = np.load(os.path.join(train_dir, "y_train.npy")) + logger.info(f"x train: {x_train.shape}, y train: {y_train.shape}") + + return torch.from_numpy(x_train), torch.from_numpy(y_train) + + +def get_test_data(test_dir): + """ + Get the testing data and convert to tensors + """ + + x_test = np.load(os.path.join(test_dir, "x_test.npy")) + y_test = np.load(os.path.join(test_dir, "y_test.npy")) + logger.info(f"x test: {x_test.shape}, y test: {y_test.shape}") + + return torch.from_numpy(x_test), torch.from_numpy(y_test) + + +def model_fn(model_dir): + """ + Load the model for inference + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model() + model.load_state_dict(torch.load(model_dir + "/model.pth")) + model.eval() + return model.to(device) + + +def input_fn(request_body, request_content_type): + """ + Deserialize and prepare the prediction input + """ + + if request_content_type == "application/json": + request = json.loads(request_body) + train_inputs = torch.tensor(request) + return train_inputs + + +def predict_fn(input_data, model): + """ + Apply model to the incoming request + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() +
with torch.no_grad(): + return model(input_data.float()).numpy()[0] + + +def train(): + """ + Train the PyTorch model + """ + # Directories: train, test and model + train_dir = os.path.join(data_dir, "train") + test_dir = os.path.join(data_dir, "test") + model_dir = os.environ.get("SM_MODEL_DIR", os.path.join(current_dir, "data/model")) + + # Load the training and testing data + x_train, y_train = get_train_data(train_dir) + x_test, y_test = get_test_data(test_dir) + train_ds = TensorDataset(x_train, y_train) + + # Training parameters - used to configure the training loop + batch_size = 64 + epochs = 1 + learning_rate = 0.1 + logger.info( + "batch_size = {}, epochs = {}, learning rate = {}".format(batch_size, epochs, learning_rate) + ) + + train_dl = DataLoader(train_ds, batch_size, shuffle=True) + + # Define the model, loss function and optimizer + model = get_model() + model = model.to(device) + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + + # Train the model + for epoch in range(epochs): + for x_train_batch, y_train_batch in train_dl: + y = model(x_train_batch.float()) + loss = criterion(y.flatten(), y_train_batch.float()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + epoch += 1 + logger.info(f"epoch: {epoch} -> loss: {loss}") + + # Test the model + with torch.no_grad(): + y = model(x_test.float()).flatten() + mse = ((y - y_test) ** 2).sum() / y_test.shape[0] + print("\nTest MSE:", mse.numpy()) + + # Save the model + os.makedirs(model_dir, exist_ok=True) + torch.save(model.state_dict(), model_dir + "/model.pth") + inference_code_path = model_dir + "/code/" + + if not os.path.exists(inference_code_path): + os.mkdir(inference_code_path) + logger.info("Created a folder at {}!".format(inference_code_path)) + + shutil.copy("local_training_script.py", inference_code_path) + shutil.copy("pytorch_model_def.py", inference_code_path) + logger.info("Saving models files to {}".format(inference_code_path)) + + +if __name__ == "__main__": + print("Running the training job ...\n") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + train() diff --git a/tests/data/modules/local_script/pytorch_model_def.py b/tests/data/modules/local_script/pytorch_model_def.py new file mode 100644 index 0000000000..2440b22f88 --- /dev/null +++ b/tests/data/modules/local_script/pytorch_model_def.py @@ -0,0 +1,23 @@ +# flake8: noqa +import torch +import torch.nn as nn + + +class NeuralNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(8, 8) + self.fc2 = nn.Linear(8, 6) + self.fc3 = nn.Linear(6, 1) + + def forward(self, x): + x = torch.tanh(self.fc1(x)) + x = torch.sigmoid(self.fc2(x)) + x = self.fc3(x) + return x + + +def get_model(): + + model = NeuralNet() + return model diff --git a/tests/data/modules/params_script/hyperparameters.json b/tests/data/modules/params_script/hyperparameters.json new file mode 100644 index 0000000000..f637288dbe --- /dev/null +++ b/tests/data/modules/params_script/hyperparameters.json @@ -0,0 +1,15 @@ +{ + "integer": 1, + "boolean": true, + "float": 3.14, + "string": "Hello World", + "list": [1, 2, 3], + "dict": { + "string": "value", + "integer": 3, + "float": 3.14, + "list": [1, 2, 3], + "dict": {"key": "value"}, + "boolean": true + } +} \ No newline at end of file diff --git a/tests/data/modules/params_script/hyperparameters.yaml b/tests/data/modules/params_script/hyperparameters.yaml new file mode 100644 index 0000000000..9e3011daf2 --- /dev/null +++ 
b/tests/data/modules/params_script/hyperparameters.yaml @@ -0,0 +1,19 @@ +integer: 1 +boolean: true +float: 3.14 +string: "Hello World" +list: + - 1 + - 2 + - 3 +dict: + string: value + integer: 3 + float: 3.14 + list: + - 1 + - 2 + - 3 + dict: + key: value + boolean: true \ No newline at end of file diff --git a/tests/data/modules/params_script/requirements.txt b/tests/data/modules/params_script/requirements.txt new file mode 100644 index 0000000000..3d2e72e354 --- /dev/null +++ b/tests/data/modules/params_script/requirements.txt @@ -0,0 +1 @@ +omegaconf diff --git a/tests/data/modules/params_script/train.py b/tests/data/modules/params_script/train.py new file mode 100644 index 0000000000..9b8cb2c82f --- /dev/null +++ b/tests/data/modules/params_script/train.py @@ -0,0 +1,232 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Script to test hyperparameters contract.""" +from __future__ import absolute_import + +import argparse +import json +import os +from typing import List, Dict, Any +from dataclasses import dataclass +from omegaconf import OmegaConf + +EXPECTED_HYPERPARAMETERS = { + "integer": 1, + "boolean": True, + "float": 3.14, + "string": "Hello World", + "list": [1, 2, 3], + "dict": { + "string": "value", + "integer": 3, + "float": 3.14, + "list": [1, 2, 3], + "dict": {"key": "value"}, + "boolean": True, + }, +} + + +def parse_args(): + parser = argparse.ArgumentParser(description="Test Hyperparameters") + parser.add_argument( + "--string", + type=str, + default=None, + required=True, + ) + parser.add_argument( + "--integer", + type=int, + default=None, + required=True, + ) + parser.add_argument( + "--float", + type=float, + default=None, + required=True, + ) + parser.add_argument( + "--boolean", + type=lambda x: json.loads(x), + default=None, + required=True, + ) + parser.add_argument( + "--list", + type=lambda x: json.loads(x), + default=None, + required=True, + ) + parser.add_argument( + "--dict", + type=lambda x: json.loads(x), + default=None, + required=True, + ) + return parser.parse_args() + + +def main(): + args = parse_args() + print(args) + + assert isinstance(args.string, str) + assert isinstance(args.integer, int) + assert isinstance(args.boolean, bool) + assert isinstance(args.float, float) + assert isinstance(args.list, list) + assert isinstance(args.dict, dict) + + assert args.string == EXPECTED_HYPERPARAMETERS["string"] + assert args.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert args.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert args.float == EXPECTED_HYPERPARAMETERS["float"] + assert args.list == EXPECTED_HYPERPARAMETERS["list"] + assert args.dict == EXPECTED_HYPERPARAMETERS["dict"] + + assert os.environ["SM_HP_STRING"] == EXPECTED_HYPERPARAMETERS["string"] + assert int(os.environ["SM_HP_INTEGER"]) == EXPECTED_HYPERPARAMETERS["integer"] + assert float(os.environ["SM_HP_FLOAT"]) == EXPECTED_HYPERPARAMETERS["float"] + assert json.loads(os.environ["SM_HP_BOOLEAN"]) == EXPECTED_HYPERPARAMETERS["boolean"] + assert 
json.loads(os.environ["SM_HP_LIST"]) == EXPECTED_HYPERPARAMETERS["list"] + assert json.loads(os.environ["SM_HP_DICT"]) == EXPECTED_HYPERPARAMETERS["dict"] + + params = json.loads(os.environ["SM_HPS"]) + print(f"SM_HPS: {params}") + assert params["string"] == EXPECTED_HYPERPARAMETERS["string"] + assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"] + assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"] + assert params["float"] == EXPECTED_HYPERPARAMETERS["float"] + assert params["list"] == EXPECTED_HYPERPARAMETERS["list"] + assert params["dict"] == EXPECTED_HYPERPARAMETERS["dict"] + + assert isinstance(params, dict) + assert isinstance(params["string"], str) + assert isinstance(params["integer"], int) + assert isinstance(params["boolean"], bool) + assert isinstance(params["float"], float) + assert isinstance(params["list"], list) + assert isinstance(params["dict"], dict) + + params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"] + print(f"SM_TRAINING_ENV -> hyperparameters: {params}") + assert params["string"] == EXPECTED_HYPERPARAMETERS["string"] + assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"] + assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"] + assert params["float"] == EXPECTED_HYPERPARAMETERS["float"] + assert params["list"] == EXPECTED_HYPERPARAMETERS["list"] + assert params["dict"] == EXPECTED_HYPERPARAMETERS["dict"] + + assert isinstance(params, dict) + assert isinstance(params["string"], str) + assert isinstance(params["integer"], int) + assert isinstance(params["boolean"], bool) + assert isinstance(params["float"], float) + assert isinstance(params["list"], list) + assert isinstance(params["dict"], dict) + + # Local JSON - DictConfig OmegaConf + params = OmegaConf.load("hyperparameters.json") + + print(f"Local hyperparameters.json: {params}") + assert params.string == EXPECTED_HYPERPARAMETERS["string"] + assert params.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert params.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert params.float == EXPECTED_HYPERPARAMETERS["float"] + assert params.list == EXPECTED_HYPERPARAMETERS["list"] + assert params.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert params.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert params.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert params.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert params.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert params.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert params.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + + @dataclass + class DictConfig: + string: str + integer: int + boolean: bool + float: float + list: List[int] + dict: Dict[str, Any] + + @dataclass + class HPConfig: + string: str + integer: int + boolean: bool + float: float + list: List[int] + dict: DictConfig + + # Local JSON - Structured OmegaConf + hp_config: HPConfig = OmegaConf.merge( + OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.json") + ) + print(f"Local hyperparameters.json - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == 
EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + + # Local YAML - Structured OmegaConf + hp_config: HPConfig = OmegaConf.merge( + OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.yaml") + ) + print(f"Local hyperparameters.yaml - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + print(f"hyperparameters.yaml -> hyperparameters: {hp_config}") + + # HP Dict - Structured OmegaConf + hp_dict = json.loads(os.environ["SM_HPS"]) + hp_config: HPConfig = OmegaConf.merge(OmegaConf.structured(HPConfig), OmegaConf.create(hp_dict)) + print(f"SM_HPS - Structured: {hp_config}") + assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"] + assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"] + assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"] + assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"] + assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"] + assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"] + assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"] + assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"] + assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"] + assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"] + assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"] + assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"] + print(f"SM_HPS -> hyperparameters: {hp_config}") + + +if __name__ == "__main__": + main() diff --git a/tests/data/modules/params_script/train.sh b/tests/data/modules/params_script/train.sh new file mode 100644 index 0000000000..20f9a3c57a --- /dev/null +++ b/tests/data/modules/params_script/train.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +echo "Do some extra work here..." + +CMD="python train.py $@" +echo "Executing Command: $CMD" + +python train.py "$@" + +echo "Done!" 
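train.py above checks the same hyperparameter payload through argparse, the SM_HP_* variables, SM_HPS, and OmegaConf; the core OmegaConf move is merging a dataclass schema with loaded values so typos and type drift fail fast. A trimmed sketch of that move, assuming omegaconf is installed:

from dataclasses import dataclass, field
from typing import List

from omegaconf import OmegaConf


@dataclass
class HPConfig:
    integer: int = 1
    boolean: bool = True
    string: str = "Hello World"
    list: List[int] = field(default_factory=lambda: [1, 2, 3])


# Merging against the schema validates and type-converts the payload;
# unknown keys or mistyped values raise instead of passing silently.
loaded = OmegaConf.create(
    {"integer": 1, "boolean": True, "string": "Hello World", "list": [1, 2, 3]}
)
hp = OmegaConf.merge(OmegaConf.structured(HPConfig), loaded)
assert hp.integer == 1 and hp.list == [1, 2, 3]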
diff --git a/tests/data/modules/script_mode/code.tar.gz b/tests/data/modules/script_mode/code.tar.gz new file mode 100644 index 0000000000..e2ed9d4b18 Binary files /dev/null and b/tests/data/modules/script_mode/code.tar.gz differ diff --git a/tests/data/modules/script_mode/custom_script.py b/tests/data/modules/script_mode/custom_script.py new file mode 100644 index 0000000000..a57ddee743 --- /dev/null +++ b/tests/data/modules/script_mode/custom_script.py @@ -0,0 +1,191 @@ +# flake8: noqa +import argparse +import numpy as np +import os +import sys +import logging +import json +import shutil +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset +from pytorch_model_def import get_model + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) +current_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_train_data(train_dir): + """ + Get the training data and convert to tensors + """ + + x_train = np.load(os.path.join(train_dir, "x_train.npy")) + y_train = np.load(os.path.join(train_dir, "y_train.npy")) + logger.info(f"x train: {x_train.shape}, y train: {y_train.shape}") + + return torch.from_numpy(x_train), torch.from_numpy(y_train) + + +def get_test_data(test_dir): + """ + Get the testing data and convert to tensors + """ + + x_test = np.load(os.path.join(test_dir, "x_test.npy")) + y_test = np.load(os.path.join(test_dir, "y_test.npy")) + logger.info(f"x test: {x_test.shape}, y test: {y_test.shape}") + + return torch.from_numpy(x_test), torch.from_numpy(y_test) + + +def model_fn(model_dir): + """ + Load the model for inference + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model() + model.load_state_dict(torch.load(model_dir + "/model.pth")) + model.eval() + return model.to(device) + + +def input_fn(request_body, request_content_type): + """ + Deserialize and prepare the prediction input + """ + + if request_content_type == "application/json": + request = json.loads(request_body) + train_inputs = torch.tensor(request) + return train_inputs + + +def predict_fn(input_data, model): + """ + Apply model to the incoming request + """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + with torch.no_grad(): + return model(input_data.float()).numpy()[0] + + +def parse_args(): + """ + Parse the command line arguments + """ + + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-dir", + type=str, + default=os.environ.get("SM_MODEL_DIR", os.path.join(current_dir, "data/model")), + help="Directory to save the model", + ) + parser.add_argument( + "--train-dir", + type=str, + default=os.environ.get("SM_CHANNEL_TRAIN", os.path.join(current_dir, "data/train")), + help="Directory containing training data", + ) + parser.add_argument( + "--test-dir", + type=str, + default=os.environ.get("SM_CHANNEL_TEST", os.path.join(current_dir, "data/test")), + help="Directory containing testing data", + ) + parser.add_argument( + "--batch-size", + type=int, + default=64, + help="Batch size for training", + ) + parser.add_argument( + "--epochs", + type=int, + default=1, + help="Number of epochs for training", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=0.1, + help="Learning rate for training", + ) + return parser.parse_args() + + +def train(): + """ + Train the PyTorch model + """ + args = parse_args() + # Directories: train, test and model + train_dir = 
args.train_dir + test_dir = args.test_dir + model_dir = args.model_dir + + # Load the training and testing data + x_train, y_train = get_train_data(train_dir) + x_test, y_test = get_test_data(test_dir) + train_ds = TensorDataset(x_train, y_train) + + # Training parameters - used to configure the training loop + batch_size = args.batch_size + epochs = args.epochs + learning_rate = args.learning_rate + logger.info( + "batch_size = {}, epochs = {}, learning rate = {}".format(batch_size, epochs, learning_rate) + ) + + train_dl = DataLoader(train_ds, batch_size, shuffle=True) + + # Define the model, loss function and optimizer + model = get_model() + model = model.to(device) + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + + # Train the model + for epoch in range(epochs): + for x_train_batch, y_train_batch in train_dl: + y = model(x_train_batch.float()) + loss = criterion(y.flatten(), y_train_batch.float()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + epoch += 1 + logger.info(f"epoch: {epoch} -> loss: {loss}") + + # Test the model + with torch.no_grad(): + y = model(x_test.float()).flatten() + mse = ((y - y_test) ** 2).sum() / y_test.shape[0] + print("\nTest MSE:", mse.numpy()) + + # Save the model + os.makedirs(model_dir, exist_ok=True) + torch.save(model.state_dict(), model_dir + "/model.pth") + inference_code_path = model_dir + "/code/" + + if not os.path.exists(inference_code_path): + os.mkdir(inference_code_path) + logger.info("Created a folder at {}!".format(inference_code_path)) + + shutil.copy("custom_script.py", inference_code_path) + shutil.copy("pytorch_model_def.py", inference_code_path) + logger.info("Saving models files to {}".format(inference_code_path)) + + +if __name__ == "__main__": + print("Running the training job ...\n") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + train() diff --git a/tests/data/modules/script_mode/data/test/x_test.npy b/tests/data/modules/script_mode/data/test/x_test.npy new file mode 100644 index 0000000000..a9977e39c0 Binary files /dev/null and b/tests/data/modules/script_mode/data/test/x_test.npy differ diff --git a/tests/data/modules/script_mode/data/test/y_test.npy b/tests/data/modules/script_mode/data/test/y_test.npy new file mode 100644 index 0000000000..a7191945ee Binary files /dev/null and b/tests/data/modules/script_mode/data/test/y_test.npy differ diff --git a/tests/data/modules/script_mode/data/train/x_train.npy b/tests/data/modules/script_mode/data/train/x_train.npy new file mode 100644 index 0000000000..d267502e65 Binary files /dev/null and b/tests/data/modules/script_mode/data/train/x_train.npy differ diff --git a/tests/data/modules/script_mode/data/train/y_train.npy b/tests/data/modules/script_mode/data/train/y_train.npy new file mode 100644 index 0000000000..b8c17c4972 Binary files /dev/null and b/tests/data/modules/script_mode/data/train/y_train.npy differ diff --git a/tests/data/modules/script_mode/pytorch_model_def.py b/tests/data/modules/script_mode/pytorch_model_def.py new file mode 100644 index 0000000000..2440b22f88 --- /dev/null +++ b/tests/data/modules/script_mode/pytorch_model_def.py @@ -0,0 +1,23 @@ +# flake8: noqa +import torch +import torch.nn as nn + + +class NeuralNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(8, 8) + self.fc2 = nn.Linear(8, 6) + self.fc3 = nn.Linear(6, 1) + + def forward(self, x): + x = torch.tanh(self.fc1(x)) + x = torch.sigmoid(self.fc2(x)) + x = self.fc3(x) + return x + + 
+def get_model(): + + model = NeuralNet() + return model diff --git a/tests/data/modules/script_mode/requirements.txt b/tests/data/modules/script_mode/requirements.txt new file mode 100644 index 0000000000..f7b8ccf0cc --- /dev/null +++ b/tests/data/modules/script_mode/requirements.txt @@ -0,0 +1,3 @@ +numpy +-f https://download.pytorch.org/whl/torch_stable.html +torch==2.7.0 diff --git a/tests/data/modules/scripts/entry_script.py b/tests/data/modules/scripts/entry_script.py new file mode 100644 index 0000000000..3c972bd956 --- /dev/null +++ b/tests/data/modules/scripts/entry_script.py @@ -0,0 +1,19 @@ +import json +import os +import time + + +def main(): + hps = json.loads(os.environ["SM_HPS"]) + assert hps != None + print(f"Hyperparameters: {hps}") + + print("Running pseudo training script") + for epochs in range(hps["epochs"]): + print(f"Epoch: {epochs}") + time.sleep(1) + print("Finished running pseudo training script") + + +if __name__ == "__main__": + main() diff --git a/tests/data/pipeline/model_step/pytorch_mnist/mnist.py b/tests/data/pipeline/model_step/pytorch_mnist/mnist.py index ef1c15ae60..d29e80b399 100644 --- a/tests/data/pipeline/model_step/pytorch_mnist/mnist.py +++ b/tests/data/pipeline/model_step/pytorch_mnist/mnist.py @@ -57,7 +57,7 @@ def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs): batch_size=batch_size, shuffle=train_sampler is None, sampler=train_sampler, - **kwargs + **kwargs, ) return train_sampler, train_loader @@ -75,7 +75,7 @@ def _get_test_data_loader(training_dir, **kwargs): ), batch_size=1000, shuffle=True, - **kwargs + **kwargs, ) diff --git a/tests/data/pipeline/model_step/pytorch_mnist/requirements.txt b/tests/data/pipeline/model_step/pytorch_mnist/requirements.txt index 56d09228be..c25fca7e9f 100644 --- a/tests/data/pipeline/model_step/pytorch_mnist/requirements.txt +++ b/tests/data/pipeline/model_step/pytorch_mnist/requirements.txt @@ -1 +1 @@ -scipy>=1.8.1 +scipy>=1.11.3 diff --git a/tests/data/pytorch_mnist/mnist.py b/tests/data/pytorch_mnist/mnist.py index ef1c15ae60..d29e80b399 100644 --- a/tests/data/pytorch_mnist/mnist.py +++ b/tests/data/pytorch_mnist/mnist.py @@ -57,7 +57,7 @@ def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs): batch_size=batch_size, shuffle=train_sampler is None, sampler=train_sampler, - **kwargs + **kwargs, ) return train_sampler, train_loader @@ -75,7 +75,7 @@ def _get_test_data_loader(training_dir, **kwargs): ), batch_size=1000, shuffle=True, - **kwargs + **kwargs, ) diff --git a/tests/data/remote_function/requirements.txt b/tests/data/remote_function/requirements.txt index 0e99587e6e..44ce1d9331 100644 --- a/tests/data/remote_function/requirements.txt +++ b/tests/data/remote_function/requirements.txt @@ -1 +1 @@ -scipy==1.10.1 +scipy==1.11.3 diff --git a/tests/data/serve_resources/mlflow/pytorch/conda.yaml b/tests/data/serve_resources/mlflow/pytorch/conda.yaml index be61456197..b740d25b70 100644 --- a/tests/data/serve_resources/mlflow/pytorch/conda.yaml +++ b/tests/data/serve_resources/mlflow/pytorch/conda.yaml @@ -9,7 +9,7 @@ dependencies: - cffi==1.16.0 - cloudpickle==2.2.1 - defusedxml==0.7.1 - - dill==0.3.8 + - dill==0.3.9 - gmpy2==2.1.2 - numpy==1.26.4 - opt-einsum==3.3.0 @@ -17,8 +17,8 @@ dependencies: - pandas==2.2.1 - pyyaml==6.0.1 - requests==2.31.0 - - torch==2.0.1 - - torchvision==0.15.2 + - torch>=2.6.0 + - torchvision>=0.17.0 - tqdm==4.66.2 - scikit-learn==1.3.2 name: mlflow-env diff --git a/tests/data/serve_resources/mlflow/pytorch/requirements.txt 
b/tests/data/serve_resources/mlflow/pytorch/requirements.txt index 9848949b0f..eabe5e8e82 100644 --- a/tests/data/serve_resources/mlflow/pytorch/requirements.txt +++ b/tests/data/serve_resources/mlflow/pytorch/requirements.txt @@ -1,16 +1,16 @@ -mlflow==2.10.2 +mlflow==2.20.3 astunparse==1.6.3 cffi==1.16.0 cloudpickle==2.2.1 defusedxml==0.7.1 -dill==0.3.8 +dill==0.3.9 gmpy2==2.1.2 -numpy==1.24.4 +numpy==1.26.4 opt-einsum==3.3.0 -packaging==21.3 +packaging>=23.0,<25 pandas==2.2.1 pyyaml==6.0.1 -requests==2.31.0 -torch==2.0.1 -torchvision==0.15.2 -tqdm==4.66.2 +requests==2.32.4 +torch>=2.6.0 +torchvision>=0.17.0 +tqdm==4.66.3 diff --git a/tests/data/serve_resources/mlflow/tensorflow/MLmodel b/tests/data/serve_resources/mlflow/tensorflow/MLmodel new file mode 100644 index 0000000000..f00412149d --- /dev/null +++ b/tests/data/serve_resources/mlflow/tensorflow/MLmodel @@ -0,0 +1,17 @@ +artifact_path: model +flavors: + python_function: + env: + conda: conda.yaml + virtualenv: python_env.yaml + loader_module: mlflow.tensorflow + python_version: 3.10.13 + tensorflow: + code: null + model_type: tf2-module + saved_model_dir: tf2model +mlflow_version: 2.11.1 +model_size_bytes: 23823 +model_uuid: 40d2323944294fce898d8693455f60e8 +run_id: 592132312fb84935b201de2c027c54c6 +utc_time_created: '2024-04-01 19:47:15.396517' diff --git a/tests/data/serve_resources/mlflow/tensorflow/conda.yaml b/tests/data/serve_resources/mlflow/tensorflow/conda.yaml new file mode 100644 index 0000000000..90d8c300a0 --- /dev/null +++ b/tests/data/serve_resources/mlflow/tensorflow/conda.yaml @@ -0,0 +1,11 @@ +channels: +- conda-forge +dependencies: +- python=3.10.13 +- pip<=23.3.1 +- pip: + - mlflow==2.11.1 + - cloudpickle==2.2.1 + - numpy==1.26.4 + - tensorflow==2.16.1 +name: mlflow-env diff --git a/tests/data/serve_resources/mlflow/tensorflow/python_env.yaml b/tests/data/serve_resources/mlflow/tensorflow/python_env.yaml new file mode 100644 index 0000000000..9e09178b6c --- /dev/null +++ b/tests/data/serve_resources/mlflow/tensorflow/python_env.yaml @@ -0,0 +1,7 @@ +python: 3.10.13 +build_dependencies: +- pip==23.3.1 +- setuptools==68.2.2 +- wheel==0.41.2 +dependencies: +- -r requirements.txt diff --git a/tests/data/serve_resources/mlflow/tensorflow/registered_model_meta b/tests/data/serve_resources/mlflow/tensorflow/registered_model_meta new file mode 100644 index 0000000000..5423c0e6c7 --- /dev/null +++ b/tests/data/serve_resources/mlflow/tensorflow/registered_model_meta @@ -0,0 +1,2 @@ +model_name: model +model_version: '2' diff --git a/tests/data/serve_resources/mlflow/tensorflow/requirements.txt b/tests/data/serve_resources/mlflow/tensorflow/requirements.txt new file mode 100644 index 0000000000..9b64992ac8 --- /dev/null +++ b/tests/data/serve_resources/mlflow/tensorflow/requirements.txt @@ -0,0 +1,4 @@ +mlflow==2.20.3 +cloudpickle==2.2.1 +numpy==1.26.4 +tensorflow==2.16.1 diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/fingerprint.pb b/tests/data/serve_resources/mlflow/tensorflow/tf2model/fingerprint.pb new file mode 100644 index 0000000000..ba1e240ba5 --- /dev/null +++ b/tests/data/serve_resources/mlflow/tensorflow/tf2model/fingerprint.pb @@ -0,0 +1 @@ +ďn/ ʢ(32 \ No newline at end of file diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/saved_model.pb b/tests/data/serve_resources/mlflow/tensorflow/tf2model/saved_model.pb new file mode 100644 index 0000000000..e48f2b59cc Binary files /dev/null and b/tests/data/serve_resources/mlflow/tensorflow/tf2model/saved_model.pb differ diff 
--git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.data-00000-of-00001 b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000..575da96282 Binary files /dev/null and b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.data-00000-of-00001 differ diff --git a/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.index b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.index new file mode 100644 index 0000000000..57646ac350 Binary files /dev/null and b/tests/data/serve_resources/mlflow/tensorflow/tf2model/variables/variables.index differ diff --git a/tests/data/serve_resources/mlflow/xgboost/requirements.txt b/tests/data/serve_resources/mlflow/xgboost/requirements.txt index 8150c9fedf..78c7a1afda 100644 --- a/tests/data/serve_resources/mlflow/xgboost/requirements.txt +++ b/tests/data/serve_resources/mlflow/xgboost/requirements.txt @@ -1,8 +1,8 @@ -mlflow==2.11.1 +mlflow==3.1.0 lz4==4.3.2 -numpy==1.24.4 +numpy==1.26.4 pandas==2.0.3 psutil==5.9.8 -scikit-learn==1.3.2 -scipy==1.10.1 +scikit-learn==1.5.1 +scipy==1.11.3 xgboost==1.7.1 diff --git a/tests/data/tensorflow_mnist/mnist_v2.py b/tests/data/tensorflow_mnist/mnist_v2.py index 05467dee49..9efb282c49 100644 --- a/tests/data/tensorflow_mnist/mnist_v2.py +++ b/tests/data/tensorflow_mnist/mnist_v2.py @@ -198,7 +198,10 @@ def main(args): if args.current_host == args.hosts[0]: ckpt_manager.save() - net.save("/opt/ml/model/1") + if int(tf_major) > 2 or (int(tf_major) == 2 and int(tf_minor) >= 16): + net.export("/opt/ml/model/1") + else: + net.save("/opt/ml/model/1") if __name__ == "__main__": diff --git a/tests/data/workflow/requirements.txt b/tests/data/workflow/requirements.txt index 0e99587e6e..44ce1d9331 100644 --- a/tests/data/workflow/requirements.txt +++ b/tests/data/workflow/requirements.txt @@ -1 +1 @@ -scipy==1.10.1 +scipy==1.11.3 diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py index 434f4dd744..a01223b256 100644 --- a/tests/integ/__init__.py +++ b/tests/integ/__init__.py @@ -87,6 +87,7 @@ "ap-south-1", "ap-northeast-2", # it has p3, but not enough "us-east-2", # it has p3, but not enough + "eu-west-1", # it has p3, but not enough ] # EI is currently only supported in the following regions diff --git a/tests/integ/sagemaker/aws_batch/__init__.py b/tests/integ/sagemaker/aws_batch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/aws_batch/manager.py b/tests/integ/sagemaker/aws_batch/manager.py new file mode 100644 index 0000000000..b417f86b53 --- /dev/null +++ b/tests/integ/sagemaker/aws_batch/manager.py @@ -0,0 +1,133 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
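mnist_v2.py above gates on the TF version because Keras 3 (bundled from TF 2.16) splits the two paths: model.save targets .keras/.h5 files, while model.export writes the SavedModel layout that TF Serving containers expect. A sketch of the same gate, assuming a local TF install and a placeholder output path:

import tensorflow as tf

major, minor = (int(x) for x in tf.__version__.split(".")[:2])
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])

if (major, minor) >= (2, 16):
    model.export("/tmp/model/1")  # Keras 3: SavedModel for serving
else:
    model.save("/tmp/model/1")    # Keras 2: legacy SavedModel path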
+from __future__ import absolute_import + +import time + + +class BatchTestResourceManager: + + def __init__( + self, + batch_client, + queue_name="pysdk-test-queue", + service_env_name="pysdk-test-queue-service-environment", + ): + self.batch_client = batch_client + self.queue_name = queue_name + self.service_environment_name = service_env_name + + def _create_or_get_service_environment(self, service_environment_name): + print(f"Creating service environment: {service_environment_name}") + try: + response = self.batch_client.create_service_environment( + serviceEnvironmentName=service_environment_name, + serviceEnvironmentType="SAGEMAKER_TRAINING", + capacityLimits=[{"maxCapacity": 10, "capacityUnit": "NUM_INSTANCES"}], + ) + print(f"Service environment {service_environment_name} created successfully.") + return response + except Exception as e: + if "Object already exists" in str(e): + print("Resource already exists. Fetching existing resource.") + response = self.batch_client.describe_service_environments( + serviceEnvironments=[service_environment_name] + ) + return response["serviceEnvironments"][0] + else: + print(f"Error creating service environment: {e}") + raise + + def _create_or_get_queue(self, queue_name, service_environment_arn): + + print(f"Creating job queue: {queue_name}") + try: + response = self.batch_client.create_job_queue( + jobQueueName=queue_name, + priority=1, + computeEnvironmentOrder=[], + serviceEnvironmentOrder=[ + { + "order": 1, + "serviceEnvironment": service_environment_arn, + }, + ], + jobQueueType="SAGEMAKER_TRAINING", + ) + print(f"Job queue {queue_name} created successfully.") + return response + except Exception as e: + if "Object already exists" in str(e): + print("Resource already exists. Fetching existing resource.") + response = self.batch_client.describe_job_queues(jobQueues=[queue_name]) + return response["jobQueues"][0] + else: + print(f"Error creating job queue: {e}") + raise + + def _update_queue_state(self, queue_name, state): + try: + print(f"Updating queue {queue_name} to state {state}") + response = self.batch_client.update_job_queue(jobQueue=queue_name, state=state) + return response + except Exception as e: + print(f"Error updating queue: {e}") + + def _update_service_environment_state(self, service_environment_name, state): + print(f"Updating service environment {service_environment_name} to state {state}") + try: + response = self.batch_client.update_service_environment( + serviceEnvironment=service_environment_name, state=state + ) + return response + except Exception as e: + print(f"Error updating service environment: {e}") + + def _wait_for_queue_state(self, queue_name, state): + print(f"Waiting for queue {queue_name} to be {state}...") + while True: + response = self.batch_client.describe_job_queues(jobQueues=[queue_name]) + print(f"Current state: {response}") + if response["jobQueues"][0]["state"] == state: + break + time.sleep(5) + print(f"Queue {queue_name} is now {state}.") + + def _wait_for_service_environment_state(self, service_environment_name, state): + print(f"Waiting for service environment {service_environment_name} to be {state}...") + while True: + response = self.batch_client.describe_service_environments( + serviceEnvironments=[service_environment_name] + ) + print(f"Current state: {response}") + if response["serviceEnvironments"][0]["state"] == state: + break + time.sleep(5) + print(f"Service environment {service_environment_name} is now {state}.") + + def get_or_create_resources(self, queue_name=None, 
service_environment_name=None): + queue_name = queue_name or self.queue_name + service_environment_name = service_environment_name or self.service_environment_name + + service_environment = self._create_or_get_service_environment(service_environment_name) + if service_environment.get("state") != "ENABLED": + self._update_service_environment_state(service_environment_name, "ENABLED") + self._wait_for_service_environment_state(service_environment_name, "ENABLED") + time.sleep(10) + + queue = self._create_or_get_queue(queue_name, service_environment["serviceEnvironmentArn"]) + if queue.get("state") != "ENABLED": + self._update_queue_state(queue_name, "ENABLED") + self._wait_for_queue_state(queue_name, "ENABLED") + time.sleep(10) + return queue, service_environment diff --git a/tests/integ/sagemaker/aws_batch/test_queue.py b/tests/integ/sagemaker/aws_batch/test_queue.py new file mode 100644 index 0000000000..20b8de55c1 --- /dev/null +++ b/tests/integ/sagemaker/aws_batch/test_queue.py @@ -0,0 +1,93 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import boto3 +import botocore +import pytest + +from sagemaker.modules.train import ModelTrainer +from sagemaker.modules.configs import SourceCode, InputData, Compute + +from sagemaker.aws_batch.training_queue import TrainingQueue + +from tests.integ import DATA_DIR +from tests.integ.sagemaker.modules.conftest import modules_sagemaker_session # noqa: F401 +from tests.integ.sagemaker.modules.train.test_model_trainer import ( + DEFAULT_CPU_IMAGE, +) +from tests.integ.sagemaker.aws_batch.manager import BatchTestResourceManager + + +@pytest.fixture(scope="module") +def batch_client(): + return boto3.client("batch", region_name="us-west-2") + + +@pytest.fixture(scope="function") +def batch_test_resource_manager(batch_client): + resource_manager = BatchTestResourceManager(batch_client=batch_client) + resource_manager.get_or_create_resources() + return resource_manager + + +def test_model_trainer_submit(batch_test_resource_manager, modules_sagemaker_session): # noqa: F811 + queue_name = batch_test_resource_manager.queue_name + + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/script_mode/", + requirements="requirements.txt", + entry_script="custom_script.py", + ) + hyperparameters = { + "batch-size": 32, + "epochs": 1, + "learning-rate": 0.01, + } + compute = Compute(instance_type="ml.m5.2xlarge") + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + source_code=source_code, + compute=compute, + hyperparameters=hyperparameters, + base_job_name="test-batch-model-trainer", + ) + train_data = InputData( + channel_name="train", + data_source=f"{DATA_DIR}/modules/script_mode/data/train/", + ) + test_data = InputData( + channel_name="test", + data_source=f"{DATA_DIR}/modules/script_mode/data/test/", + ) + + training_queue = TrainingQueue(queue_name=queue_name) + + try: + queued_job = training_queue.submit( + training_job=model_trainer, + 
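# TrainingQueue fronts an AWS Batch job queue for SageMaker Training jobs: + # submit() returns a queued-job handle whose describe()/wait() report the + # Batch status as it moves from SUBMITTED toward SUCCEEDED +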
inputs=[train_data, test_data], + ) + except botocore.exceptions.ClientError as e: + print(e.response["ResponseMetadata"]) + print(e.response["Error"]["Message"]) + raise e + res = queued_job.describe() + assert res is not None + assert res["status"] == "SUBMITTED" + + queued_job.wait(timeout=1800) + res = queued_job.describe() + assert res is not None + assert res["status"] == "SUCCEEDED" diff --git a/tests/integ/sagemaker/conftest.py b/tests/integ/sagemaker/conftest.py index 2dc9f7df4d..421ef10b1d 100644 --- a/tests/integ/sagemaker/conftest.py +++ b/tests/integ/sagemaker/conftest.py @@ -14,16 +14,16 @@ import base64 import os -import subprocess -import shutil -import pytest -import docker import re +import shutil +import subprocess import sys +import docker +import pytest from docker.errors import BuildError -from sagemaker.utils import sagemaker_timestamp, _tmpdir, sts_regional_endpoint +from sagemaker.utils import _tmpdir, sagemaker_timestamp, sts_regional_endpoint REPO_ACCOUNT_ID = "033110030271" @@ -46,8 +46,8 @@ 'SHELL ["/bin/bash", "-c"]\n' "RUN apt-get update -y \ && apt-get install -y unzip curl\n\n" - "RUN curl -L -O 'https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh' \ - && bash Mambaforge-Linux-x86_64.sh -b -p '/opt/conda' \ + "RUN curl -L -O 'https://github.com/conda-forge/miniforge/releases/download/24.11.3-2/Miniforge3-Linux-x86_64.sh' \ + && bash Miniforge3-Linux-x86_64.sh -b -p '/opt/conda' \ && /opt/conda/bin/conda init bash\n\n" "ENV PATH $PATH:/opt/conda/bin\n" "RUN mamba create -n integ_test_env python={py_version} -y \ @@ -68,7 +68,7 @@ "RUN curl 'https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip' -o 'awscliv2.zip' \ && unzip awscliv2.zip \ && ./aws/install\n\n" - "RUN apt install sudo\n" + "RUN apt install -y sudo\n" "RUN useradd -ms /bin/bash integ-test-user\n" # Add the user to sudo group "RUN usermod -aG sudo integ-test-user\n" @@ -86,8 +86,8 @@ 'SHELL ["/bin/bash", "-c"]\n' "RUN apt-get update -y \ && apt-get install -y unzip curl\n\n" - "RUN curl -L -O 'https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh' \ - && bash Mambaforge-Linux-x86_64.sh -b -p '/opt/conda' \ + "RUN curl -L -O 'https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh' \ + && bash Miniforge3-Linux-x86_64.sh -b -p '/opt/conda' \ && /opt/conda/bin/conda init bash\n\n" "ENV PATH $PATH:/opt/conda/bin\n" "COPY {source_archive} ./\n" @@ -102,7 +102,9 @@ "channels:\n" " - defaults\n" "dependencies:\n" - " - scipy=1.10.1\n" + " - requests=2.32.3\n" + " - charset-normalizer=3.3.2\n" + " - scipy=1.13.1\n" " - pip:\n" " - /sagemaker-{sagemaker_version}.tar.gz\n" "prefix: /opt/conda/bin/conda\n" @@ -176,7 +178,7 @@ def conda_env_yml(): os.remove(conda_yml_file_name) -def _build_container(sagemaker_session, py_version, docker_templete): +def _build_container(sagemaker_session, py_version, docker_template): """Build a dummy test container locally and push a container to an ecr repo""" region = sagemaker_session.boto_region_name @@ -189,9 +191,9 @@ def _build_container(sagemaker_session, py_version, docker_templete): print("building source archive...") source_archive = _generate_sagemaker_sdk_tar(tmpdir) with open(os.path.join(tmpdir, "Dockerfile"), "w") as file: - file.writelines( - docker_templete.format(py_version=py_version, source_archive=source_archive) - ) + content = docker_template.format(py_version=py_version, source_archive=source_archive) + 
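# echo the rendered Dockerfile so failed image builds can be diagnosed from CI logs +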
print(f"Dockerfile contents: \n{content}\n") + file.writelines(content) docker_client = docker.from_env() @@ -209,6 +211,7 @@ def _build_container(sagemaker_session, py_version, docker_templete): raise if _is_repository_exists(ecr_client, REPO_NAME): + print("pushing to session configured account id!") sts_client = sagemaker_session.boto_session.client( "sts", region_name=region, endpoint_url=sts_regional_endpoint(region) ) @@ -218,6 +221,7 @@ def _build_container(sagemaker_session, py_version, docker_templete): account_id, sagemaker_session.boto_region_name, REPO_NAME, image_tag ) else: + print(f"pushing to account id: {REPO_ACCOUNT_ID}") ecr_image = _ecr_image_uri( REPO_ACCOUNT_ID, sagemaker_session.boto_region_name, @@ -232,7 +236,7 @@ def _build_container(sagemaker_session, py_version, docker_templete): return ecr_image -def _build_auto_capture_client_container(py_version, docker_templete): +def _build_auto_capture_client_container(py_version, docker_template): """Build a test docker container that will act as a client for auto_capture tests""" with _tmpdir() as tmpdir: print("building docker image locally in ", tmpdir) @@ -240,9 +244,9 @@ def _build_auto_capture_client_container(py_version, docker_templete): source_archive = _generate_sdk_tar_with_public_version(tmpdir) _move_auto_capture_test_file(tmpdir) with open(os.path.join(tmpdir, "Dockerfile"), "w") as file: - file.writelines( - docker_templete.format(py_version=py_version, source_archive=source_archive) - ) + content = docker_template.format(py_version=py_version, source_archive=source_archive) + print(f"Dockerfile contents: \n{content}\n") + file.writelines(content) docker_client = docker.from_env() @@ -276,11 +280,14 @@ def _generate_sagemaker_sdk_tar(destination_folder): """ Run setup.py sdist to generate the PySDK tar file """ - subprocess.run( - f"python3 setup.py egg_info --egg-base {destination_folder} sdist -d {destination_folder} -k", - shell=True, - check=True, - ) + command = f"python -m build --sdist -o {destination_folder}" + print(f"Running command: {command}") + result = subprocess.run(command, shell=True, check=True, capture_output=True) + if result.returncode != 0: + print(f"Command failed with return code: {result.returncode}") + + print(f"Standard output: {result.stdout.decode()}") + print(f"Standard error: {result.stderr.decode()}") destination_folder_contents = os.listdir(destination_folder) source_archive = [file for file in destination_folder_contents if file.endswith("tar.gz")][0] diff --git a/tests/integ/sagemaker/experiments/conftest.py b/tests/integ/sagemaker/experiments/conftest.py index 693e147392..1f4ae26247 100644 --- a/tests/integ/sagemaker/experiments/conftest.py +++ b/tests/integ/sagemaker/experiments/conftest.py @@ -155,7 +155,7 @@ def tempdir(): @pytest.fixture(scope="module") def dev_sdk_tar(): resource_dir = os.path.join(DATA_DIR, "experiment") - os.system("python setup.py sdist -k") + os.system("python -m build --sdist") sdist_path = max(glob.glob("dist/sagemaker-*"), key=os.path.getctime) sdk_file = os.path.join(resource_dir, _EXP_PLUS_SDK_TAR) shutil.copy(sdist_path, sdk_file) diff --git a/tests/integ/sagemaker/experiments/helpers.py b/tests/integ/sagemaker/experiments/helpers.py index 9a22c3a30c..c8f35471b1 100644 --- a/tests/integ/sagemaker/experiments/helpers.py +++ b/tests/integ/sagemaker/experiments/helpers.py @@ -13,9 +13,12 @@ from __future__ import absolute_import from contextlib import contextmanager +import pytest +import logging from sagemaker import utils from 
sagemaker.experiments.experiment import Experiment +from sagemaker.experiments._run_context import _RunContext EXP_INTEG_TEST_NAME_PREFIX = "experiments-integ" @@ -40,3 +43,17 @@ def cleanup_exp_resources(exp_names, sagemaker_session): for exp_name in exp_names: exp = Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session) exp._delete_all(action="--force") + + +@pytest.fixture +def clear_run_context(): + current_run = _RunContext.get_current_run() + if current_run is None: + return + + logging.info( + f"RunContext already populated by run {current_run.run_name}" + f" in experiment {current_run.experiment_name}." + " Clearing context manually" + ) + _RunContext.drop_current_run() diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py index 4f59d11c54..f00f53a5ad 100644 --- a/tests/integ/sagemaker/experiments/test_run.py +++ b/tests/integ/sagemaker/experiments/test_run.py @@ -720,8 +720,8 @@ def _generate_processor( ) return FrameworkProcessor( estimator_cls=PyTorch, - framework_version="1.10", - py_version="py38", + framework_version="1.13.1", + py_version="py39", instance_count=1, instance_type="ml.m5.xlarge", role=execution_role, diff --git a/tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py b/tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor_integ.py similarity index 99% rename from tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py rename to tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor_integ.py index 0d1ada759d..fb69bb1b3f 100644 --- a/tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py +++ b/tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor_integ.py @@ -1272,7 +1272,7 @@ def _generate_and_move_sagemaker_sdk_tar(): """ Run setup.py sdist to generate the PySDK whl file """ - subprocess.run("python3 setup.py bdist_wheel", shell=True) + subprocess.run("python -m build --wheel", shell=True) dist_dir = "dist" source_archive = os.listdir(dist_dir)[0] source_path = os.path.join(dist_dir, source_archive) diff --git a/tests/integ/sagemaker/jumpstart/conftest.py b/tests/integ/sagemaker/jumpstart/conftest.py index c7554f3e51..260b0f2b22 100644 --- a/tests/integ/sagemaker/jumpstart/conftest.py +++ b/tests/integ/sagemaker/jumpstart/conftest.py @@ -16,24 +16,43 @@ import boto3 import pytest from botocore.config import Config +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.hub.hub import Hub from sagemaker.session import Session from tests.integ.sagemaker.jumpstart.constants import ( ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, + ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, + HUB_NAME_PREFIX, JUMPSTART_TAG, ) +from sagemaker.jumpstart.types import ( + HubContentType, +) + from tests.integ.sagemaker.jumpstart.utils import ( get_test_artifact_bucket, get_test_suite_id, + get_sm_session, + with_exponential_backoff, ) -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME - def _setup(): print("Setting up...") - os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: get_test_suite_id()}) + test_suite_id = get_test_suite_id() + test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}" + test_hub_description = "PySDK Integ Test Private Hub" + + os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: test_suite_id}) + os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: test_hub_name}) + + # Create a private hub to 
use for the test session + hub = Hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + ) + hub.create(description=test_hub_description) def _teardown(): @@ -43,6 +62,8 @@ def _teardown(): test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID] + test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] + boto3_session = boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME) sagemaker_client = boto3_session.client( @@ -113,6 +134,29 @@ def _teardown(): bucket = s3_resource.Bucket(test_cache_bucket) bucket.objects.filter(Prefix=test_suite_id + "/").delete() + # delete private hubs + _delete_hubs(sagemaker_session, test_hub_name) + + +def _delete_hubs(sagemaker_session, hub_name): + # list and delete all hub contents first + list_hub_content_response = sagemaker_session.list_hub_contents( + hub_name=hub_name, hub_content_type=HubContentType.MODEL_REFERENCE.value + ) + for model in list_hub_content_response["HubContentSummaries"]: + _delete_hub_contents(sagemaker_session, hub_name, model) + + sagemaker_session.delete_hub(hub_name) + + +@with_exponential_backoff() +def _delete_hub_contents(sagemaker_session, hub_name, model): + sagemaker_session.delete_hub_content_reference( + hub_name=hub_name, + hub_content_type=HubContentType.MODEL_REFERENCE.value, + hub_content_name=model["HubContentName"], + ) + @pytest.fixture(scope="session", autouse=True) def setup(request): diff --git a/tests/integ/sagemaker/jumpstart/constants.py b/tests/integ/sagemaker/jumpstart/constants.py index f5ffbf7a3a..740d88e9c0 100644 --- a/tests/integ/sagemaker/jumpstart/constants.py +++ b/tests/integ/sagemaker/jumpstart/constants.py @@ -37,17 +37,22 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str: ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID = "JUMPSTART_SDK_TEST_SUITE_ID" +ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME = "JUMPSTART_SDK_TEST_HUB_NAME" + JUMPSTART_TAG = "JumpStart-SDK-Integ-Test-Suite-Id" +HUB_NAME_PREFIX = "PySDK-HubTest-" TRAINING_DATASET_MODEL_DICT = { ("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"), ("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"), - ("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"), + ("huggingface-spc-bert-base-cased", "2.0.3"): ("training-datasets/QNLI-tiny/"), + ("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI/"), ("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"), ("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"), ("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"), ("meta-textgeneration-llama-2-7b", "3.*"): ("training-datasets/sec_amazon/"), + ("meta-textgeneration-llama-2-7b", "4.*"): ("training-datasets/sec_amazon/"), ("meta-textgenerationneuron-llama-2-7b", "*"): ("training-datasets/sec_amazon/"), } diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index a839a293c5..00c87fac1b 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -140,7 +140,7 @@ def test_gated_model_training_v1(setup): def test_gated_model_training_v2(setup): model_id = "meta-textgeneration-llama-2-7b" - model_version = "3.*" # model artifacts retrieved from jumpstart-private-cache-* buckets + model_version = "4.*" # model artifacts retrieved from jumpstart-private-cache-* 
buckets estimator = JumpStartEstimator( model_id=model_id, @@ -150,6 +150,7 @@ def test_gated_model_training_v2(setup): tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], environment={"accept_eula": "true"}, max_run=259200, # avoid exceeding resource limits + tolerate_vulnerable_model=True, # tolerate old version of model ) # uses ml.g5.12xlarge instance @@ -173,6 +174,7 @@ def test_gated_model_training_v2(setup): tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), + instance_type="ml.g5.2xlarge", ) payload = { diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index 5205765e2f..c9a39ac3dc 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -11,7 +11,10 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import + +import io import os +import sys import time from unittest import mock @@ -50,6 +53,8 @@ "ap-southeast-2", } +TEST_HUB_WITH_REFERENCE = "mock-hub-name" + def test_non_prepacked_jumpstart_model(setup): @@ -165,7 +170,7 @@ def test_jumpstart_gated_model(setup): model = JumpStartModel( model_id=model_id, - model_version="3.*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets + model_version="*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), ) @@ -192,7 +197,7 @@ def test_jumpstart_gated_model_inference_component_enabled(setup): model = JumpStartModel( model_id=model_id, - model_version="3.*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets + model_version="*", # version >=3.0.0 stores artifacts in jumpstart-private-cache-* buckets role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), ) @@ -219,9 +224,14 @@ def test_jumpstart_gated_model_inference_component_enabled(setup): assert response is not None + model = JumpStartModel.attach(predictor.endpoint_name, sagemaker_session=get_sm_session()) + assert model.model_id == model_id + assert model.endpoint_name == predictor.endpoint_name + assert model.inference_component_name == predictor.component_name + @mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") -def test_instatiating_model(mock_warning_logger, setup): +def test_instantiating_model(mock_warning_logger, setup): model_id = "catboost-regression-model" @@ -260,6 +270,8 @@ def test_jumpstart_model_register(setup): response = predictor.predict("hello world!") + predictor.delete_predictor() + assert response is not None @@ -286,3 +298,271 @@ def test_proprietary_jumpstart_model(setup): response = predictor.predict(payload) assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable if test account is subscribed to the proprietary model", +) +def test_register_proprietary_jumpstart_model(setup): + + model_id = "ai21-jurassic-2-light" + + model = JumpStartModel( + model_id=model_id, + model_version="2.0.004", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + model_package = model.register() + + predictor = model_package.deploy( + tags=[{"Key": JUMPSTART_TAG, 
"Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}] + ) + payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + + response = predictor.predict(payload) + + predictor.delete_predictor() + + assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable if test account is subscribed to the proprietary model", +) +def test_register_gated_jumpstart_model(setup): + + model_id = "meta-textgenerationneuron-llama-2-7b" + model = JumpStartModel( + model_id=model_id, + model_version="1.1.0", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + model_package = model.register(accept_eula=True) + + predictor = model_package.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + accept_eula=True, + ) + payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1} + + response = predictor.predict(payload) + + predictor.delete_predictor() + + assert response is not None + + +@pytest.mark.skipif( + True, + reason="Only enable after metadata is fully deployed.", +) +def test_jumpstart_model_with_deployment_configs(setup): + model_id = "meta-textgeneration-llama-2-13b" + + model = JumpStartModel( + model_id=model_id, + model_version="*", + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + captured_output = io.StringIO() + sys.stdout = captured_output + model.display_benchmark_metrics() + sys.stdout = sys.__stdout__ + assert captured_output.getvalue() is not None + + configs = model.list_deployment_configs() + assert len(configs) > 0 + + model.set_deployment_config( + configs[0]["ConfigName"], + "ml.g5.2xlarge", + ) + assert model.config_name == configs[0]["ConfigName"] + + predictor = model.deploy( + accept_eula=True, + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + payload = { + "inputs": "some-payload", + "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}, + } + + response = predictor.predict(payload, custom_attributes="accept_eula=true") + + assert response is not None + + +def test_jumpstart_session_with_config_name(): + model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b") + assert model.config_name is not None + session = model.sagemaker_session + + # we're mocking the http request, so it's expected to raise an Exception. + # we're interested that the low-level request attaches the correct + # jumpstart-related tags. 
+ with mock.patch("botocore.client.BaseClient._make_request") as mock_make_request: + try: + session.sagemaker_client.list_endpoints() + except Exception: + pass + + assert ( + "md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#tgi" + in mock_make_request.call_args[0][1]["headers"]["User-Agent"] + ) + + +def _setup_test_hub_with_reference(public_hub_model_id: str): + session = get_sm_session() + + try: + session.create_hub( + hub_name=TEST_HUB_WITH_REFERENCE, + hub_description="this is my sagemaker hub", + hub_display_name="Mock Hub", + hub_search_keywords=["mock", "hub", "123"], + s3_storage_config={"S3OutputPath": "s3://my-hub-bucket/"}, + tags=[{"Key": "tag-key-1", "Value": "tag-value-1"}], + ) + except Exception as e: + if "ResourceInUse" in str(e): + print("Hub already exists") + else: + raise e + + try: + session.create_hub_content_reference( + hub_name=TEST_HUB_WITH_REFERENCE, + source_hub_content_arn=( + f"arn:aws:sagemaker:{session.boto_region_name}:aws:" + f"hub-content/SageMakerPublicHub/Model/{public_hub_model_id}" + ), + ) + except Exception as e: + if "ResourceInUse" in str(e): + print("Reference already exists") + else: + raise e + + +def _teardown_test_hub_with_reference(public_hub_model_id: str): + session = get_sm_session() + + try: + session.delete_hub_content_reference( + hub_name=TEST_HUB_WITH_REFERENCE, + hub_content_type="ModelReference", + hub_content_name=public_hub_model_id, + ) + except Exception as e: + if "ResourceInUse" in str(e): + print("Reference already exists") + else: + raise e + + try: + session.delete_hub(hub_name=TEST_HUB_WITH_REFERENCE) + except Exception as e: + if "ResourceInUse" in str(e): + print("Hub already exists") + else: + raise e + + +@pytest.mark.skip +# Currently JumpStartModel does not pull from HubService for the Public Hub. 
+def test_model_reference_marketplace_model(setup): + session = get_sm_session() + + # TODO: hardcoded model ID is brittle - should be dynamic pull via ListHubContents + public_hub_marketplace_model_id = "upstage-solar-mini-chat" + _setup_test_hub_with_reference(public_hub_marketplace_model_id) + + JumpStartModel( # Retrieving MP model None -> defaults to latest SemVer + model_id=public_hub_marketplace_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + ) + + model_semver = JumpStartModel( # Retrieving MP model SemVer -> uses SemVer + model_id=public_hub_marketplace_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + model_version="1.0.0", + ) + + model_marketplace_version = JumpStartModel( # Retrieving MP model MP version -> uses MPver + model_id=public_hub_marketplace_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + model_version="240612.5", + ) + + _teardown_test_hub_with_reference(public_hub_marketplace_model_id) # Cleanup before assertions + + assert model_semver.model_version == model_marketplace_version.model_version + + +# TODO: PySDK test account not subscribed to this model +# def test_model_reference_marketplace_model_deployment(setup): +# session = get_sm_session() +# public_hub_marketplace_model_id = "upstage-solar-mini-chat" +# _setup_test_hub_with_reference(public_hub_marketplace_model_id) + +# marketplace_model = JumpStartModel( # Retrieving MP model MP version -> uses MPver +# model_id=public_hub_marketplace_model_id, +# hub_name=TEST_HUB_WITH_REFERENCE, +# role=session.get_caller_identity_arn(), +# sagemaker_session=session, +# model_version="240612.5", +# ) +# predictor = marketplace_model.deploy( +# tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], +# accept_eula=True, +# ) + +# predictor.delete_predictor() +# _teardown_test_hub_with_reference(public_hub_marketplace_model_id) + + +@pytest.mark.skip +def test_bedrock_store_model_tags_from_hub_service(setup): + + session = get_sm_session() + brs_model_id = "huggingface-llm-gemma-2b-instruct" + _setup_test_hub_with_reference(brs_model_id) + + brs_model = JumpStartModel( + model_id=brs_model_id, + hub_name=TEST_HUB_WITH_REFERENCE, + role=session.get_caller_identity_arn(), + sagemaker_session=session, + ) + + predictor = brs_model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + accept_eula=True, + ) + + endpoint_arn = ( + f"arn:aws:sagemaker:{session.boto_region_name}:" + f"{session.account_id()}:endpoint/{predictor.endpoint_name}" + ) + tags = session.list_tags(endpoint_arn) + + predictor.delete_predictor() # Cleanup before assertions + _teardown_test_hub_with_reference(brs_model_id) + + expected_tag = {"Key": "sagemaker-sdk:bedrock", "Value": "compatible"} + assert expected_tag in tags diff --git a/tests/integ/sagemaker/jumpstart/private_hub/__init__.py b/tests/integ/sagemaker/jumpstart/private_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py b/tests/integ/sagemaker/jumpstart/private_hub/estimator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py 
b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py new file mode 100644 index 0000000000..a6e33f1bdf --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py @@ -0,0 +1,204 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import time + +import pytest +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.hub.hub import Hub + +from sagemaker.jumpstart.estimator import JumpStartEstimator +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket + +from tests.integ.sagemaker.jumpstart.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, + ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, + JUMPSTART_TAG, +) +from tests.integ.sagemaker.jumpstart.utils import ( + get_public_hub_model_arn, + get_sm_session, + with_exponential_backoff, + get_training_dataset_for_model_and_version, +) + +MAX_INIT_TIME_SECONDS = 5 + +TEST_MODEL_IDS = { + "huggingface-spc-bert-base-cased", + "meta-textgeneration-llama-2-7b", + "catboost-regression-model", +} + + +@with_exponential_backoff() +def create_model_reference(hub_instance, model_arn): + try: + hub_instance.create_model_reference(model_arn=model_arn) + except Exception: + pass + + +@pytest.fixture(scope="session") +def add_model_references(): + # Create Model References to test in Hub + hub_instance = Hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + ) + for model in TEST_MODEL_IDS: + model_arn = get_public_hub_model_arn(hub_instance, model) + create_model_reference(hub_instance, model_arn) + + +def test_jumpstart_hub_estimator(setup, add_model_references): + model_id, model_version = "huggingface-spc-bert-base-cased", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + estimator.fit( + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + # test that we can create a JumpStartEstimator from existing job with `attach` + estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + ) + + # uses ml.p3.2xlarge instance + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + response = predictor.predict(["hello", "world"]) + + assert response is not None + + +def test_jumpstart_hub_estimator_with_session(setup, add_model_references): + + model_id, model_version = "huggingface-spc-bert-base-cased", "*" + + sagemaker_session = get_sm_session() + + estimator = JumpStartEstimator( + model_id=model_id, + role=sagemaker_session.get_caller_identity_arn(), + 
sagemaker_session=sagemaker_session, + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + estimator.fit( + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + # test that we can create a JumpStartEstimator from existing job with `attach` + estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + sagemaker_session=get_sm_session(), + ) + + # uses ml.p3.2xlarge instance + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + response = predictor.predict(["hello", "world"]) + + assert response is not None + + +def test_jumpstart_hub_gated_estimator_with_eula(setup, add_model_references): + + model_id, model_version = "meta-textgeneration-llama-2-7b", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + estimator.fit( + accept_eula=True, + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + }, + ) + + predictor = estimator.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + payload = { + "inputs": "some-payload", + "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}, + } + + response = predictor.predict(payload, custom_attributes="accept_eula=true") + + assert response is not None + + +def test_jumpstart_hub_gated_estimator_without_eula(setup, add_model_references): + + model_id, model_version = "meta-textgeneration-llama-2-7b", "*" + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + with pytest.raises(Exception): + estimator.fit( + inputs={ + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", + } + ) + + +def test_instantiating_estimator(setup, add_model_references): + + model_id = "catboost-regression-model" + + start_time = time.perf_counter() + + JumpStartEstimator( + model_id=model_id, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + elapsed_time = time.perf_counter() - start_time + + assert elapsed_time <= MAX_INIT_TIME_SECONDS diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/__init__.py b/tests/integ/sagemaker/jumpstart/private_hub/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py new file mode 100644 index 0000000000..c7e039693b --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py @@ -0,0 +1,192 @@ +# Copyright 
Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import time + +import pytest +from sagemaker.enums import EndpointType +from sagemaker.jumpstart.hub.hub import Hub +from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs +from sagemaker.predictor import retrieve_default + +import tests.integ + +from sagemaker.jumpstart.model import JumpStartModel +from tests.integ.sagemaker.jumpstart.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, + ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID, + JUMPSTART_TAG, +) +from tests.integ.sagemaker.jumpstart.utils import ( + get_public_hub_model_arn, + get_sm_session, + with_exponential_backoff, +) + +MAX_INIT_TIME_SECONDS = 5 + +TEST_MODEL_IDS = { + "catboost-classification-model", + "huggingface-txt2img-conflictx-complex-lineart", + "meta-textgeneration-llama-2-7b", + "meta-textgeneration-llama-3-2-1b", + "catboost-regression-model", +} + + +@with_exponential_backoff() +def create_model_reference(hub_instance, model_arn): + try: + hub_instance.create_model_reference(model_arn=model_arn) + except Exception: + pass + + +@pytest.fixture(scope="session") +def add_model_references(): + # Create Model References to test in Hub + hub_instance = Hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + ) + for model in TEST_MODEL_IDS: + model_arn = get_public_hub_model_arn(hub_instance, model) + create_model_reference(hub_instance, model_arn) + + +def test_jumpstart_hub_model(setup, add_model_references): + + model_id = "catboost-classification-model" + + sagemaker_session = get_sm_session() + + model = JumpStartModel( + model_id=model_id, + role=sagemaker_session.get_caller_identity_arn(), + sagemaker_session=sagemaker_session, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + predictor = model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name) + + +def test_jumpstart_hub_model_with_default_session(setup, add_model_references): + model_version = "*" + hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] + + model_id = "catboost-classification-model" + + sagemaker_session = get_sm_session() + + model = JumpStartModel(model_id=model_id, model_version=model_version, hub_name=hub_name) + + predictor = model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + assert sagemaker_session.endpoint_in_service_or_not(predictor.endpoint_name) + + +def test_jumpstart_hub_gated_model(setup, add_model_references): + + model_id = "meta-textgeneration-llama-3-2-1b" + + model = JumpStartModel( + model_id=model_id, + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + predictor = model.deploy( + accept_eula=True, + tags=[{"Key": JUMPSTART_TAG, "Value": 
os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + payload = model.retrieve_example_payload() + + response = predictor.predict(payload) + + assert response is not None + + +@pytest.mark.skip(reason="blocking PR checks and release pipeline.") +def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references): + + model_id = "meta-textgeneration-llama-3-2-1b" + + hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME] + + region = tests.integ.test_region() + + sagemaker_session = get_sm_session() + + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + + model = JumpStartModel( + model_id=model_id, + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=sagemaker_session, + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + model.deploy( + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + accept_eula=True, + endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, + ) + + predictor = retrieve_default( + endpoint_name=model.endpoint_name, + sagemaker_session=sagemaker_session, + tolerate_vulnerable_model=True, + hub_arn=hub_arn, + ) + + payload = model.retrieve_example_payload() + + response = predictor.predict(payload) + + assert response is not None + + model = JumpStartModel.attach( + predictor.endpoint_name, sagemaker_session=sagemaker_session, hub_name=hub_name + ) + assert model.model_id == model_id + assert model.endpoint_name == predictor.endpoint_name + assert model.inference_component_name == predictor.component_name + + +def test_instantiating_model(setup, add_model_references): + + model_id = "catboost-regression-model" + + start_time = time.perf_counter() + + JumpStartModel( + model_id=model_id, + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], + ) + + elapsed_time = time.perf_counter() - start_time + + assert elapsed_time <= MAX_INIT_TIME_SECONDS diff --git a/tests/integ/sagemaker/jumpstart/private_hub/test_hub.py b/tests/integ/sagemaker/jumpstart/private_hub/test_hub.py new file mode 100644 index 0000000000..db5d868c06 --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/private_hub/test_hub.py @@ -0,0 +1,53 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
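A pattern worth noting from the private-hub model tests above: an endpoint deployed through a hub can later be re-attached by name, recovering its JumpStart metadata. A minimal sketch, assuming the test module's get_sm_session helper, its env-var constants, and a hypothetical endpoint_name taken from a prior deploy:

from sagemaker.jumpstart.model import JumpStartModel

# hub_name scopes the metadata lookup to the private hub under test
model = JumpStartModel.attach(
    endpoint_name,  # hypothetical: a prior predictor.endpoint_name
    sagemaker_session=get_sm_session(),
    hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
)
assert model.model_id is not None  # model_id and inference_component_name are restored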
+from __future__ import absolute_import +import pytest +from sagemaker.jumpstart.hub.hub import Hub + +from tests.integ.sagemaker.jumpstart.utils import ( + get_sm_session, + get_test_suite_id, +) +from tests.integ.sagemaker.jumpstart.constants import ( + HUB_NAME_PREFIX, +) + + +@pytest.fixture +def hub_instance(): + # HUB_NAME_PREFIX already ends in a hyphen, so the suite id is appended directly + HUB_NAME = f"{HUB_NAME_PREFIX}{get_test_suite_id()}" + hub = Hub(HUB_NAME, sagemaker_session=get_sm_session()) + yield hub + + +@pytest.mark.skip +def test_private_hub(setup, hub_instance): + # Create Hub + create_hub_response = hub_instance.create( + description="This is a Test Private Hub.", + display_name="PySDK integration tests Hub", + search_keywords=["jumpstart-sdk-integ-test"], + ) + + # Create Hub Verifications + assert create_hub_response is not None + + # Describe Hub + hub_description = hub_instance.describe() + assert hub_description is not None + + # Delete Hub + delete_hub_response = hub_instance.delete() + assert delete_hub_response is not None diff --git a/tests/integ/sagemaker/jumpstart/private_hub/test_hub_content.py b/tests/integ/sagemaker/jumpstart/private_hub/test_hub_content.py new file mode 100644 index 0000000000..04b945a457 --- /dev/null +++ b/tests/integ/sagemaker/jumpstart/private_hub/test_hub_content.py @@ -0,0 +1,46 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import os +from sagemaker.jumpstart.hub.hub import Hub + +from sagemaker.jumpstart.hub.interfaces import DescribeHubContentResponse +from tests.integ.sagemaker.jumpstart.utils import ( + get_sm_session, +) +from tests.integ.sagemaker.jumpstart.utils import get_public_hub_model_arn +from tests.integ.sagemaker.jumpstart.constants import ( + ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME, +) + + +def test_hub_model_reference(setup): + model_id = "meta-textgenerationneuron-llama-3-2-1b-instruct" + + hub_instance = Hub( + hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session() + ) + + create_model_response = hub_instance.create_model_reference( + model_arn=get_public_hub_model_arn(hub_instance, model_id) + ) + assert create_model_response is not None + + describe_model_response = hub_instance.describe_model(model_name=model_id) + assert describe_model_response is not None + assert isinstance(describe_model_response, DescribeHubContentResponse) + assert describe_model_response.hub_content_name == model_id + assert describe_model_response.hub_content_type == "ModelReference" + + delete_model_response = hub_instance.delete_model_reference(model_name=model_id) + assert delete_model_response is not None diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py index 4ad58153e9..7fbac1cdba 100644 --- a/tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py @@ -77,6 +77,10 @@ def package_artifacts(self): self.model_name = self.get_model_name() + if self.script_uri is None: + print("No script uri provided. Not performing prepack") + return self.model_uri + cache_bucket_uri = f"s3://{get_test_artifact_bucket()}" repacked_model_uri = "/".join( [ @@ -147,16 +151,26 @@ def get_model_name(self) -> str: return f"{non_timestamped_name}{self.suffix}" def create_model(self) -> None: + primary_container = { + "Image": self.image_uri, + "Mode": "SingleModel", + "Environment": self.environment_variables, + } + if self.repacked_model_uri.endswith(".tar.gz"): + primary_container["ModelDataUrl"] = self.repacked_model_uri + else: + primary_container["ModelDataSource"] = { + "S3DataSource": { + "S3Uri": self.repacked_model_uri, + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } self.sagemaker_client.create_model( ModelName=self.model_name, EnableNetworkIsolation=True, ExecutionRoleArn=self.execution_role, - PrimaryContainer={ - "Image": self.image_uri, - "ModelDataUrl": self.repacked_model_uri, - "Mode": "SingleModel", - "Environment": self.environment_variables, - }, + PrimaryContainer=primary_container, ) def create_endpoint_config(self) -> None: diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py index 550e2481cd..5f23428908 100644 --- a/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/test_inference.py @@ -17,7 +17,6 @@ InferenceJobLauncher, ) from sagemaker import environment_variables, image_uris -from sagemaker import script_uris from sagemaker import model_uris from tests.integ.sagemaker.jumpstart.constants import InferenceTabularDataname @@ -31,8 +30,8 @@ def test_jumpstart_inference_retrieve_functions(setup): - model_id, model_version = "catboost-classification-model", "1.0.0" - instance_type = "ml.m5.xlarge" + model_id, model_version = 
"catboost-classification-model", "2.1.6" + instance_type = "ml.m5.4xlarge" print("Starting inference...") @@ -46,13 +45,6 @@ def test_jumpstart_inference_retrieve_functions(setup): tolerate_vulnerable_model=True, ) - script_uri = script_uris.retrieve( - model_id=model_id, - model_version=model_version, - script_scope="inference", - tolerate_vulnerable_model=True, - ) - model_uri = model_uris.retrieve( model_id=model_id, model_version=model_version, @@ -68,7 +60,7 @@ def test_jumpstart_inference_retrieve_functions(setup): inference_job = InferenceJobLauncher( image_uri=image_uri, - script_uri=script_uri, + script_uri=None, model_uri=model_uri, instance_type=instance_type, base_name="catboost", diff --git a/tests/integ/sagemaker/jumpstart/retrieve_uri/test_transfer_learning.py b/tests/integ/sagemaker/jumpstart/retrieve_uri/test_transfer_learning.py index 4e65cc5b58..7cb0f34fbf 100644 --- a/tests/integ/sagemaker/jumpstart/retrieve_uri/test_transfer_learning.py +++ b/tests/integ/sagemaker/jumpstart/retrieve_uri/test_transfer_learning.py @@ -33,7 +33,7 @@ def test_jumpstart_transfer_learning_retrieve_functions(setup): - model_id, model_version = "huggingface-spc-bert-base-cased", "1.0.0" + model_id, model_version = "huggingface-spc-bert-base-cased", "2.0.3" training_instance_type = "ml.p3.2xlarge" inference_instance_type = "ml.p2.xlarge" diff --git a/tests/integ/sagemaker/jumpstart/utils.py b/tests/integ/sagemaker/jumpstart/utils.py index 0f2fd01572..d439ef7e95 100644 --- a/tests/integ/sagemaker/jumpstart/utils.py +++ b/tests/integ/sagemaker/jumpstart/utils.py @@ -14,6 +14,8 @@ import functools import json +import random +import time import uuid from typing import Any, Dict, List, Tuple import boto3 @@ -21,6 +23,7 @@ import os from botocore.config import Config +from botocore.exceptions import ClientError import pytest @@ -32,6 +35,7 @@ ) from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.jumpstart.hub.hub import Hub from sagemaker.session import Session @@ -49,23 +53,18 @@ def get_sm_session() -> Session: return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)) -# def get_sm_session_with_override() -> Session: -# # [TODO]: Remove service endpoint override before GA -# # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) -# boto_session = boto3.Session(region_name="us-west-2") -# sagemaker = boto3.client( -# service_name="sagemaker-internal", -# endpoint_url="https://sagemaker.beta.us-west-2.ml-platform.aws.a2z.com", -# ) -# sagemaker_runtime = boto3.client( -# service_name="runtime.maeve", -# endpoint_url="https://maeveruntime.beta.us-west-2.ml-platform.aws.a2z.com", -# ) -# return Session( -# boto_session=boto_session, -# sagemaker_client=sagemaker, -# sagemaker_runtime_client=sagemaker_runtime, -# ) +def get_sm_session_with_override() -> Session: + # [TODO]: Remove service endpoint override before GA + # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) + boto_session = boto3.Session(region_name="us-west-2") + sagemaker = boto3.client( + service_name="sagemaker", + endpoint_url="https://sagemaker.gamma.us-west-2.ml-platform.aws.a2z.com", + ) + return Session( + boto_session=boto_session, + sagemaker_client=sagemaker, + ) def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict: @@ -115,6 +114,41 @@ def download_file(local_download_path, s3_bucket, s3_key, s3_client) -> None: s3_client.download_file(s3_bucket, 
s3_key, local_download_path) +def get_public_hub_model_arn(hub: Hub, model_id: str) -> str: + filter_value = f"model_id == {model_id}" + response = hub.list_sagemaker_public_hub_models(filter=filter_value) + + models = response["hub_content_summaries"] + + return models[0]["hub_content_arn"] + + +def with_exponential_backoff(max_retries=5, initial_delay=1, max_delay=60): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + retries = 0 + while True: + try: + return func(*args, **kwargs) + except ClientError as e: + if retries >= max_retries or e.response["Error"]["Code"] not in [ + "ThrottlingException", + "TooManyRequestsException", + ]: + raise + delay = min(initial_delay * (2**retries) + random.random(), max_delay) + print( + f"Retrying {func.__name__} in {delay:.2f} seconds... (Attempt {retries + 1}/{max_retries})" + ) + time.sleep(delay) + retries += 1 + + return wrapper + + return decorator + + class EndpointInvoker: def __init__( self, diff --git a/tests/integ/sagemaker/modules/__init__.py b/tests/integ/sagemaker/modules/__init__.py new file mode 100644 index 0000000000..9d8bffee3f --- /dev/null +++ b/tests/integ/sagemaker/modules/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Placeholder docstring""" diff --git a/tests/integ/sagemaker/modules/conftest.py b/tests/integ/sagemaker/modules/conftest.py new file mode 100644 index 0000000000..d6d3877de4 --- /dev/null +++ b/tests/integ/sagemaker/modules/conftest.py @@ -0,0 +1,40 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module contains code to test image builder""" +from __future__ import absolute_import + +import pytest + +import os +import boto3 +from sagemaker.modules import Session + +DEFAULT_REGION = "us-west-2" + + +@pytest.fixture(scope="module") +def modules_sagemaker_session(): + region = os.environ.get("AWS_DEFAULT_REGION") + if not region: + os.environ["AWS_DEFAULT_REGION"] = DEFAULT_REGION + region_manual_set = True + else: + region_manual_set = False + + boto_session = boto3.Session(region_name=os.environ["AWS_DEFAULT_REGION"]) + sagemaker_session = Session(boto_session=boto_session) + + yield sagemaker_session + + if region_manual_set and "AWS_DEFAULT_REGION" in os.environ: + del os.environ["AWS_DEFAULT_REGION"] diff --git a/tests/integ/sagemaker/modules/train/test_local_model_trainer.py b/tests/integ/sagemaker/modules/train/test_local_model_trainer.py new file mode 100644 index 0000000000..7947b2fc87 --- /dev/null +++ b/tests/integ/sagemaker/modules/train/test_local_model_trainer.py @@ -0,0 +1,224 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing peCWDissions and limitations under the License. +"""This module contains code to test image builder with local mode""" +from __future__ import absolute_import +import os +import errno + +import shutil +import tempfile + +from tests.integ import DATA_DIR +import tests.integ.lock as lock + +from sagemaker.modules.configs import Compute, InputData, SourceCode +from sagemaker.modules.distributed import Torchrun +from sagemaker.modules.train.model_trainer import Mode, ModelTrainer +import subprocess + +DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310" +CWD = os.getcwd() +SOURCE_DIR = os.path.join(DATA_DIR, "modules/local_script") +LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_local_mode_lock") + + +def delete_local_path(path): + try: + if os.path.exists(path) and os.path.isdir(path): + shutil.rmtree(path) + print(f"Removed directory: {path}") + else: + print(f"Directory does not exist: {path}") + except OSError as exc: + # on Linux, when docker writes to any mounted volume, it uses the container's user. In most + # cases this is root. When the container exits and we try to delete them we can't because + # root owns those files. We expect this to happen, so we handle EACCESS. Any other error + # we will raise the exception up. 
+ if exc.errno == errno.EACCES: + print(f"Failed to delete: {path} Please remove it manually.") + else: + print(f"Failed to delete: {path}") + raise + + +def test_single_container_local_mode_local_data(modules_sagemaker_session): + with lock.lock(LOCK_PATH): + try: + source_code = SourceCode( + source_dir=SOURCE_DIR, + entry_script="local_training_script.py", + ) + + compute = Compute( + instance_type="local_cpu", + instance_count=1, + ) + + train_data = InputData( + channel_name="train", + data_source=os.path.join(SOURCE_DIR, "data/train/"), + ) + + test_data = InputData( + channel_name="test", + data_source=os.path.join(SOURCE_DIR, "data/test/"), + ) + + model_trainer = ModelTrainer( + training_image=DEFAULT_CPU_IMAGE, + sagemaker_session=modules_sagemaker_session, + source_code=source_code, + compute=compute, + input_data_config=[train_data, test_data], + base_job_name="local_mode_single_container_local_data", + training_mode=Mode.LOCAL_CONTAINER, + ) + + model_trainer.train() + assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz")) + finally: + subprocess.run(["docker", "compose", "down", "-v"]) + directories = [ + "compressed_artifacts", + "artifacts", + "model", + "output", + ] + + for directory in directories: + path = os.path.join(CWD, directory) + delete_local_path(path) + + +def test_single_container_local_mode_s3_data(modules_sagemaker_session): + with lock.lock(LOCK_PATH): + try: + # upload local data to s3 + session = modules_sagemaker_session + bucket = session.default_bucket() + session.upload_data( + path=os.path.join(SOURCE_DIR, "data/train/"), + bucket=bucket, + key_prefix="data/train", + ) + session.upload_data( + path=os.path.join(SOURCE_DIR, "data/test/"), + bucket=bucket, + key_prefix="data/test", + ) + + source_code = SourceCode( + source_dir=SOURCE_DIR, + entry_script="local_training_script.py", + ) + + compute = Compute( + instance_type="local_cpu", + instance_count=1, + ) + + # read input data from s3 + train_data = InputData(channel_name="train", data_source=f"s3://{bucket}/data/train/") + + test_data = InputData(channel_name="test", data_source=f"s3://{bucket}/data/test/") + + model_trainer = ModelTrainer( + training_image=DEFAULT_CPU_IMAGE, + sagemaker_session=modules_sagemaker_session, + source_code=source_code, + compute=compute, + input_data_config=[train_data, test_data], + base_job_name="local_mode_single_container_s3_data", + training_mode=Mode.LOCAL_CONTAINER, + ) + + model_trainer.train() + assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz")) + finally: + subprocess.run(["docker", "compose", "down", "-v"]) + + assert not os.path.exists(os.path.join(CWD, "shared")) + assert not os.path.exists(os.path.join(CWD, "input")) + assert not os.path.exists(os.path.join(CWD, "algo-1")) + + directories = [ + "compressed_artifacts", + "artifacts", + "model", + "output", + ] + + for directory in directories: + path = os.path.join(CWD, directory) + delete_local_path(path) + + +def test_multi_container_local_mode(modules_sagemaker_session): + with lock.lock(LOCK_PATH): + try: + source_code = SourceCode( + source_dir=SOURCE_DIR, + entry_script="local_training_script.py", + ) + + distributed = Torchrun( + process_count_per_node=1, + ) + + compute = Compute( + instance_type="local_cpu", + instance_count=2, + ) + + train_data = InputData( + channel_name="train", + data_source=os.path.join(SOURCE_DIR, "data/train/"), + ) + + test_data = InputData( + channel_name="test", + data_source=os.path.join(SOURCE_DIR, "data/test/"), + ) 
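# (Editor's note, not part of the diff: with Compute(instance_count=2) and
# Torchrun(process_count_per_node=1) above, LOCAL_CONTAINER mode emulates a
# two-node torchrun cluster through docker compose, i.e. an expected world size
# of instance_count * process_count_per_node = 2; that is why the teardown
# below runs `docker compose down -v` and checks the algo-1/algo-2 directories.)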
+ + model_trainer = ModelTrainer( + training_image=DEFAULT_CPU_IMAGE, + sagemaker_session=modules_sagemaker_session, + source_code=source_code, + distributed=distributed, + compute=compute, + input_data_config=[train_data, test_data], + base_job_name="local_mode_multi_container", + training_mode=Mode.LOCAL_CONTAINER, + ) + + model_trainer.train() + assert os.path.exists(os.path.join(CWD, "compressed_artifacts/model.tar.gz")) + + finally: + subprocess.run(["docker", "compose", "down", "-v"]) + + assert not os.path.exists(os.path.join(CWD, "shared")) + assert not os.path.exists(os.path.join(CWD, "input")) + assert not os.path.exists(os.path.join(CWD, "algo-1")) + assert not os.path.exists(os.path.join(CWD, "algo-2")) + + directories = [ + "compressed_artifacts", + "artifacts", + "model", + "output", + ] + + for directory in directories: + path = os.path.join(CWD, directory) + delete_local_path(path) diff --git a/tests/integ/sagemaker/modules/train/test_model_trainer.py b/tests/integ/sagemaker/modules/train/test_model_trainer.py new file mode 100644 index 0000000000..332b536d77 --- /dev/null +++ b/tests/integ/sagemaker/modules/train/test_model_trainer.py @@ -0,0 +1,178 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to test image builder""" +from __future__ import absolute_import + +from tests.integ import DATA_DIR + +from sagemaker.modules.train import ModelTrainer +from sagemaker.modules.configs import SourceCode, Compute +from sagemaker.modules.distributed import MPI, Torchrun, DistributedConfig + +EXPECTED_HYPERPARAMETERS = { + "integer": 1, + "boolean": True, + "float": 3.14, + "string": "Hello World", + "list": [1, 2, 3], + "dict": { + "string": "value", + "integer": 3, + "float": 3.14, + "list": [1, 2, 3], + "dict": {"key": "value"}, + "boolean": True, + }, +} + +PARAM_SCRIPT_SOURCE_DIR = f"{DATA_DIR}/modules/params_script" +PARAM_SCRIPT_SOURCE_CODE = SourceCode( + source_dir=PARAM_SCRIPT_SOURCE_DIR, + requirements="requirements.txt", + entry_script="train.py", +) + +DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310" + +TAR_FILE_SOURCE_DIR = f"{DATA_DIR}/modules/script_mode/code.tar.gz" +TAR_FILE_SOURCE_CODE = SourceCode( + source_dir=TAR_FILE_SOURCE_DIR, + requirements="requirements.txt", + entry_script="custom_script.py", +) + + +def test_source_dir_local_tar_file(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + source_code=TAR_FILE_SOURCE_CODE, + base_job_name="source_dir_local_tar_file", + ) + + model_trainer.train() + + +def test_hp_contract_basic_py_script(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=EXPECTED_HYPERPARAMETERS, + source_code=PARAM_SCRIPT_SOURCE_CODE, + base_job_name="hp-contract-basic-py-script", + ) + + model_trainer.train() + + +def 
test_hp_contract_basic_sh_script(modules_sagemaker_session): + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/params_script", + requirements="requirements.txt", + entry_script="train.sh", + ) + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=EXPECTED_HYPERPARAMETERS, + source_code=source_code, + base_job_name="hp-contract-basic-sh-script", + ) + + model_trainer.train() + + +def test_hp_contract_mpi_script(modules_sagemaker_session): + compute = Compute(instance_type="ml.m5.xlarge", instance_count=2) + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + compute=compute, + hyperparameters=EXPECTED_HYPERPARAMETERS, + source_code=PARAM_SCRIPT_SOURCE_CODE, + distributed=MPI(), + base_job_name="hp-contract-mpi-script", + ) + + model_trainer.train() + + +def test_hp_contract_torchrun_script(modules_sagemaker_session): + compute = Compute(instance_type="ml.m5.xlarge", instance_count=2) + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + compute=compute, + hyperparameters=EXPECTED_HYPERPARAMETERS, + source_code=PARAM_SCRIPT_SOURCE_CODE, + distributed=Torchrun(), + base_job_name="hp-contract-torchrun-script", + ) + + model_trainer.train() + + +def test_hp_contract_hyperparameter_json(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.json", + source_code=PARAM_SCRIPT_SOURCE_CODE, + base_job_name="hp-contract-hyperparameter-json", + ) + assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS + model_trainer.train() + + +def test_hp_contract_hyperparameter_yaml(modules_sagemaker_session): + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=f"{PARAM_SCRIPT_SOURCE_DIR}/hyperparameters.yaml", + source_code=PARAM_SCRIPT_SOURCE_CODE, + base_job_name="hp-contract-hyperparameter-yaml", + ) + assert model_trainer.hyperparameters == EXPECTED_HYPERPARAMETERS + model_trainer.train() + + +def test_custom_distributed_driver(modules_sagemaker_session): + class CustomDriver(DistributedConfig): + process_count_per_node: int = None + + @property + def driver_dir(self) -> str: + return f"{DATA_DIR}/modules/custom_drivers" + + @property + def driver_script(self) -> str: + return "driver.py" + + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/scripts", + entry_script="entry_script.py", + ) + + hyperparameters = {"epochs": 10} + + custom_driver = CustomDriver(process_count_per_node=2) + + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=hyperparameters, + source_code=source_code, + distributed=custom_driver, + base_job_name="custom-distributed-driver", + ) + model_trainer.train() diff --git a/tests/integ/sagemaker/remote_function/test_decorator.py b/tests/integ/sagemaker/remote_function/test_decorator.py index 63ced1dd9c..fa55d7dfa7 100644 --- a/tests/integ/sagemaker/remote_function/test_decorator.py +++ b/tests/integ/sagemaker/remote_function/test_decorator.py @@ -818,3 +818,26 @@ def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container): f"--rm {auto_capture_test_container}" ) subprocess.check_output(shlex.split(cmd), 
stderr=subprocess.STDOUT).decode("utf-8") + + +def test_decorator_torchrun( + sagemaker_session, + dummy_container_without_error, + gpu_instance_type, + use_torchrun=False, + use_mpirun=False, +): + @remote( + role=ROLE, + image_uri=dummy_container_without_error, + instance_type=gpu_instance_type, + sagemaker_session=sagemaker_session, + keep_alive_period_in_seconds=60, + use_torchrun=use_torchrun, + use_mpirun=use_mpirun, + ) + def divide(x, y): + return x / y + + assert divide(10, 2) == 5 + assert divide(20, 2) == 10 diff --git a/tests/integ/sagemaker/serve/conftest.py b/tests/integ/sagemaker/serve/conftest.py index a1086afea7..5eb3a2ea11 100644 --- a/tests/integ/sagemaker/serve/conftest.py +++ b/tests/integ/sagemaker/serve/conftest.py @@ -10,64 +10,48 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -# from __future__ import absolute_import +from __future__ import absolute_import -# import os -# import pytest -# import platform -# import collections -# from numpy import loadtxt -# from sagemaker.serve.spec.inference_spec import InferenceSpec +import pytest +import os +import boto3 +import sagemaker +import sagemaker_core.helper.session_helper as core_session -# if platform.python_version_tuple()[1] == "8": -# from xgboost import XGBClassifier -# from sklearn.model_selection import train_test_split +DEFAULT_REGION = "us-west-2" -# from tests.integ.sagemaker.serve.constants import XGB_RESOURCE_DIR +@pytest.fixture(scope="module") +def mb_sagemaker_session(): + region = os.environ.get("AWS_DEFAULT_REGION") + if not region: + os.environ["AWS_DEFAULT_REGION"] = DEFAULT_REGION + region_manual_set = True + else: + region_manual_set = False -# XgbTestSplit = collections.namedtuple("XgbTrainTestSplit", "x_test y_test") + boto_session = boto3.Session(region_name=os.environ["AWS_DEFAULT_REGION"]) + sagemaker_session = sagemaker.Session(boto_session=boto_session) + yield sagemaker_session -# @pytest.fixture(scope="session") -# def loaded_xgb_model(): -# model = XGBClassifier() -# model.load_model(XGB_RESOURCE_DIR + "/model.xgb") -# return model + if region_manual_set and "AWS_DEFAULT_REGION" in os.environ: + del os.environ["AWS_DEFAULT_REGION"] -# @pytest.fixture(scope="session") -# def xgb_inference_spec(): -# class MyXGBoostModel(InferenceSpec): -# def load(self, model_dir: str): -# model = XGBClassifier() -# model.load_model(model_dir + "/model.xgb") -# return model +@pytest.fixture(scope="module") +def mb_sagemaker_core_session(): + region = os.environ.get("AWS_DEFAULT_REGION") + if not region: + os.environ["AWS_DEFAULT_REGION"] = DEFAULT_REGION + region_manual_set = True + else: + region_manual_set = False -# def invoke( -# self, -# input: object, -# model: object, -# ): -# y_pred = model.predict(input) -# predictions = [round(value) for value in y_pred] -# return predictions + boto_session = boto3.Session(region_name=os.environ["AWS_DEFAULT_REGION"]) + sagemaker_session = core_session.Session(boto_session=boto_session) -# return MyXGBoostModel() + yield sagemaker_session - -# @pytest.fixture(scope="session") -# def xgb_test_sets(): -# dataset = loadtxt( -# os.path.join(XGB_RESOURCE_DIR, "classification_training_data.data.csv"), delimiter="," -# ) - -# X = dataset[:, 0:8] -# Y = dataset[:, 8] - -# seed = 7 -# test_size = 0.33 - -# _, x_test, _, y_test = train_test_split(X, Y, test_size=test_size, random_state=seed) - -# return
XgbTestSplit(x_test, y_test) + if region_manual_set and "AWS_DEFAULT_REGION" in os.environ: + del os.environ["AWS_DEFAULT_REGION"] diff --git a/tests/integ/sagemaker/serve/constants.py b/tests/integ/sagemaker/serve/constants.py index 794f7333a3..3f25f6a575 100644 --- a/tests/integ/sagemaker/serve/constants.py +++ b/tests/integ/sagemaker/serve/constants.py @@ -25,6 +25,7 @@ PYTHON_VERSION_IS_NOT_38 = platform.python_version_tuple()[1] != "8" PYTHON_VERSION_IS_NOT_310 = platform.python_version_tuple()[1] != "10" +PYTHON_VERSION_IS_NOT_312 = platform.python_version_tuple()[1] != "12" XGB_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "xgboost") PYTORCH_SQUEEZENET_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "pytorch") @@ -32,6 +33,7 @@ DATA_DIR, "serve_resources", "mlflow", "pytorch" ) XGBOOST_MLFLOW_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "mlflow", "xgboost") +TENSORFLOW_MLFLOW_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "mlflow", "tensorflow") TF_EFFICIENT_RESOURCE_DIR = os.path.join(DATA_DIR, "serve_resources", "tensorflow") HF_DIR = os.path.join(DATA_DIR, "serve_resources", "hf") diff --git a/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py b/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py new file mode 100644 index 0000000000..a0de64225d --- /dev/null +++ b/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py @@ -0,0 +1,227 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
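# --- Editor's sketch (assumption, not part of this change): the new
# --- PYTHON_VERSION_IS_NOT_312 constant presumably gates tests the same way
# --- the existing *_310 flags are used in this suite; the test below is
# --- hypothetical.
import pytest
from tests.integ.sagemaker.serve.constants import PYTHON_VERSION_IS_NOT_312

@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_312,
    reason="This test exercises a py312-only serving path",
)
def test_py312_only_serving_path():
    ...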
+from __future__ import absolute_import + +import os +import uuid +from typing import Generator + +import numpy as np +import pandas as pd +import pytest +from sagemaker_core.main.resources import TrainingJob +from sagemaker_core.main.shapes import ( + AlgorithmSpecification, + Channel, + DataSource, + OutputDataConfig, + ResourceConfig, + S3DataSource, + StoppingCondition, +) +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from xgboost import XGBClassifier + +from sagemaker import get_execution_role +from sagemaker.async_inference import AsyncInferenceConfig +from sagemaker.s3_utils import s3_path_join +from sagemaker.serve import InferenceSpec, SchemaBuilder +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig +from tests.integ.utils import cleanup_model_resources + + +@pytest.fixture(autouse=True) +def cleanup_endpoints(mb_sagemaker_session) -> Generator[None, None, None]: + """Clean up any existing endpoints before and after tests.""" + sagemaker_client = mb_sagemaker_session.sagemaker_client + + # Pre-test cleanup + try: + endpoints = sagemaker_client.list_endpoints() + for endpoint in endpoints["Endpoints"]: + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint["EndpointName"]) + sagemaker_client.delete_endpoint_config( + EndpointConfigName=endpoint["EndpointConfigName"] + ) + except Exception as e: + print(f"Error cleaning up endpoint {endpoint['EndpointName']}: {e}") + except Exception as e: + print(f"Error listing endpoints: {e}") + + yield + + # Post-test cleanup + try: + endpoints = sagemaker_client.list_endpoints() + for endpoint in endpoints["Endpoints"]: + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint["EndpointName"]) + sagemaker_client.delete_endpoint_config( + EndpointConfigName=endpoint["EndpointConfigName"] + ) + except Exception as e: + print(f"Error cleaning up endpoint {endpoint['EndpointName']}: {e}") + except Exception as e: + print(f"Error listing endpoints: {e}") + + +@pytest.fixture(scope="module") +def xgboost_model_builder(mb_sagemaker_session): + sagemaker_session = mb_sagemaker_session + role = get_execution_role(sagemaker_session=sagemaker_session) + bucket = sagemaker_session.default_bucket() + + # Get IRIS Data + iris = load_iris() + iris_df = pd.DataFrame(iris.data, columns=iris.feature_names) + iris_df["target"] = iris.target + + # Prepare Data + os.makedirs("data", exist_ok=True) + + iris_df = iris_df[["target"] + [col for col in iris_df.columns if col != "target"]] + + train_data, test_data = train_test_split(iris_df, test_size=0.2, random_state=42) + + train_data.to_csv("data/train.csv", index=False, header=False) + test_data.to_csv("data/test.csv", index=False, header=False) + + # Remove the target column from the testing data. 
We will use this to call invoke_endpoint later + test_data = test_data.drop("target", axis=1) + + prefix = "DEMO-scikit-iris" + TRAIN_DATA = "train.csv" + DATA_DIRECTORY = "data" + + sagemaker_session.upload_data( + DATA_DIRECTORY, bucket=bucket, key_prefix="{}/{}".format(prefix, DATA_DIRECTORY) + ) + + s3_input_path = "s3://{}/{}/data/{}".format(bucket, prefix, TRAIN_DATA) + s3_output_path = "s3://{}/{}/output".format(bucket, prefix) + + print(s3_input_path) + print(s3_output_path) + + image = "433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:1" + + class XGBoostSpec(InferenceSpec): + def load(self, model_dir: str): + print(model_dir) + model = XGBClassifier() + model.load_model(model_dir + "/xgboost-model") + return model + + def invoke(self, input_object: object, model: object): + prediction_probabilities = model.predict_proba(input_object) + predictions = np.argmax(prediction_probabilities, axis=1) + return predictions + + data = {"Name": ["Alice", "Bob", "Charlie"]} + df = pd.DataFrame(data) + training_job_name = str(uuid.uuid4()) + schema_builder = SchemaBuilder(sample_input=df, sample_output=df) + + training_job = TrainingJob.create( + training_job_name=training_job_name, + hyper_parameters={ + "objective": "multi:softmax", + "num_class": "3", + "num_round": "10", + "eval_metric": "merror", + }, + algorithm_specification=AlgorithmSpecification( + training_image=image, training_input_mode="File" + ), + role_arn=role, + input_data_config=[ + Channel( + channel_name="train", + content_type="csv", + compression_type="None", + record_wrapper_type="None", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=s3_input_path, + s3_data_distribution_type="FullyReplicated", + ) + ), + ) + ], + output_data_config=OutputDataConfig(s3_output_path=s3_output_path), + resource_config=ResourceConfig( + instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=30 + ), + stopping_condition=StoppingCondition(max_runtime_in_seconds=600), + ) + training_job.wait() + + xgboost_model_builder = ModelBuilder( + name="ModelBuilderTest", + model_path=training_job.model_artifacts.s3_model_artifacts, + role_arn=role, + inference_spec=XGBoostSpec(), + image_uri=image, + schema_builder=schema_builder, + instance_type="ml.c6i.xlarge", + ) + xgboost_model_builder.build() + return xgboost_model_builder + + +def test_real_time_deployment(xgboost_model_builder): + real_time_predictor = xgboost_model_builder.deploy( + endpoint_name=f"test-{uuid.uuid1().hex}", initial_instance_count=1 + ) + + assert real_time_predictor is not None + cleanup_model_resources( + sagemaker_session=xgboost_model_builder.sagemaker_session, + model_name=xgboost_model_builder.built_model.name, + endpoint_name=xgboost_model_builder.built_model.endpoint_name, + ) + + +def test_serverless_deployment(xgboost_model_builder): + serverless_predictor = xgboost_model_builder.deploy( + endpoint_name=f"test1-{uuid.uuid1().hex}", inference_config=ServerlessInferenceConfig() + ) + + assert serverless_predictor is not None + cleanup_model_resources( + sagemaker_session=xgboost_model_builder.sagemaker_session, + model_name=xgboost_model_builder.built_model.name, + endpoint_name=xgboost_model_builder.built_model.endpoint_name, + ) + + +def test_async_deployment(xgboost_model_builder, mb_sagemaker_session): + async_predictor = xgboost_model_builder.deploy( + endpoint_name=f"test2-{uuid.uuid1().hex}", + inference_config=AsyncInferenceConfig( + output_path=s3_path_join( + "s3://", mb_sagemaker_session.default_bucket(), "async_inference/output" +
) + ), + ) + + assert async_predictor is not None + cleanup_model_resources( + sagemaker_session=xgboost_model_builder.sagemaker_session, + model_name=xgboost_model_builder.built_model.name, + endpoint_name=xgboost_model_builder.built_model.endpoint_name, + ) diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py index a0c1673ae8..6d3e8281d5 100644 --- a/tests/integ/sagemaker/serve/test_schema_builder.py +++ b/tests/integ/sagemaker/serve/test_schema_builder.py @@ -33,7 +33,11 @@ def test_model_builder_happy_path_with_only_model_id_text_generation(sagemaker_session): - model_builder = ModelBuilder(model="HuggingFaceH4/zephyr-7b-beta") + model_builder = ModelBuilder( + model="HuggingFaceH4/zephyr-7b-beta", + sagemaker_session=sagemaker_session, + instance_type=None, + ) model = model_builder.build(sagemaker_session=sagemaker_session) @@ -47,7 +51,9 @@ def test_model_builder_happy_path_with_only_model_id_text_generation(sagemaker_s def test_model_builder_negative_path(sagemaker_session): # A model-task combo unsupported by both the local and remote schema fallback options. (eg: text-to-video) - model_builder = ModelBuilder(model="ByteDance/AnimateDiff-Lightning") + model_builder = ModelBuilder( + model="ByteDance/AnimateDiff-Lightning", sagemaker_session=sagemaker_session + ) with pytest.raises( TaskNotFoundException, match="Error Message: HuggingFace Schema builder samples for text-to-video could not be found locally or " @@ -86,6 +92,7 @@ def test_model_builder_happy_path_with_task_provided_local_schema_mode( model=model_id, model_metadata={"HF_TASK": task_provided}, instance_type=instance_type_provided, + sagemaker_session=sagemaker_session, ) model = model_builder.build(sagemaker_session=sagemaker_session) @@ -111,13 +118,13 @@ def test_model_builder_happy_path_with_task_provided_local_schema_mode( if container_startup_timeout: predictor = model.deploy( role=role_arn, - instance_count=1, + initial_instance_count=1, instance_type=instance_type_provided, container_startup_health_check_timeout=container_startup_timeout, ) else: predictor = model.deploy( - role=role_arn, instance_count=1, instance_type=instance_type_provided + role=role_arn, initial_instance_count=1, instance_type=instance_type_provided ) predicted_outputs = predictor.predict(inputs) @@ -162,6 +169,7 @@ def test_model_builder_happy_path_with_task_provided_remote_schema_mode( model=model_id, model_metadata={"HF_TASK": task_provided}, instance_type=instance_type_provided, + sagemaker_session=sagemaker_session, ) model = model_builder.build(sagemaker_session=sagemaker_session) @@ -181,7 +189,7 @@ def test_model_builder_happy_path_with_task_provided_remote_schema_mode( logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") predictor = model.deploy( - role=role_arn, instance_count=1, instance_type=instance_type_provided + role=role_arn, initial_instance_count=1, instance_type=instance_type_provided ) predicted_outputs = predictor.predict(inputs) @@ -217,6 +225,7 @@ def test_model_builder_with_task_provided_remote_schema_mode_asr( model=model_id, model_metadata={"HF_TASK": task_provided}, instance_type=instance_type_provided, + sagemaker_session=sagemaker_session, ) model = model_builder.build(sagemaker_session=sagemaker_session) @@ -231,7 +240,9 @@ def test_model_builder_with_task_provided_remote_schema_mode_asr( def test_model_builder_negative_path_with_invalid_task(sagemaker_session): model_builder = ModelBuilder( - model="bert-base-uncased", 
model_metadata={"HF_TASK": "invalid-task"} + model="bert-base-uncased", + model_metadata={"HF_TASK": "invalid-task"}, + sagemaker_session=sagemaker_session, ) with pytest.raises( diff --git a/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py b/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py new file mode 100644 index 0000000000..3b59cae321 --- /dev/null +++ b/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py @@ -0,0 +1,263 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from unittest.mock import MagicMock, patch, ANY + +from sagemaker.session import Session +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serve.builder.schema_builder import SchemaBuilder +from sagemaker.resource_requirements import ResourceRequirements + +ROLE_NAME = "SageMakerRole" + + +def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_expected( + sagemaker_session, +): + with ( + patch.object(Session, "create_model", return_value="mock_model") as mock_create_model, + patch.object( + Session, "endpoint_from_production_variants" + ) as mock_endpoint_from_production_variants, + ): + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] + + sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() + + schema_builder = SchemaBuilder("test", "test") + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-1-8b-instruct", + schema_builder=schema_builder, + sagemaker_session=sagemaker_session, + role_arn=role_arn, + ) + + optimized_model = model_builder.optimize( + instance_type="ml.g5.xlarge", # set to small instance in case a network call is made + speculative_decoding_config={ + "ModelProvider": "JumpStart", + "ModelID": "meta-textgeneration-llama-3-2-1b", + "AcceptEula": True, + }, + accept_eula=True, + ) + + assert not sagemaker_session.sagemaker_client.create_optimization_job.called + + optimized_model.deploy() + + mock_create_model.assert_called_once_with( + name=ANY, + role=ANY, + container_defs={ + "Image": ANY, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model/", + }, + "AdditionalModelDataSources": [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": ANY, + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"AcceptEula": True}, + }, + } + ], + "ModelDataSource": { + "S3DataSource": { + "S3Uri": ANY, + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"AcceptEula": True}, + } + }, + }, + vpc_config=None, + enable_network_isolation=True, + tags=ANY, + ) + 
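# (Editor's note, not part of the diff: Session.create_model and
# Session.endpoint_from_production_variants are patched above, so deploy()
# never reaches SageMaker; the assertions only verify the request shape that
# would be sent, including the draft-model entry in AdditionalModelDataSources
# used for speculative decoding.)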
mock_endpoint_from_production_variants.assert_called_once() + + +def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_expected( + sagemaker_session, +): + with ( + patch.object( + Session, + "wait_for_optimization_job", + return_value={"OptimizationJobName": "mock_optimization_job"}, + ), + patch.object(Session, "create_model", return_value="mock_model") as mock_create_model, + patch.object( + Session, "endpoint_from_production_variants", return_value="mock_endpoint_name" + ) as mock_endpoint_from_production_variants, + patch.object(Session, "create_inference_component") as mock_create_inference_component, + ): + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] + + sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() + + schema_builder = SchemaBuilder("test", "test") + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-1-8b-instruct", + schema_builder=schema_builder, + sagemaker_session=sagemaker_session, + role_arn=role_arn, + ) + + optimized_model = model_builder.optimize( + instance_type="ml.g5.xlarge", # set to small instance in case a network call is made + sharding_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "8"}}, + accept_eula=True, + ) + + assert ( + sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelShardingConfig"]["Image"] + is not None + ) + + optimized_model.deploy( + resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8}) + ) + + mock_create_model.assert_called_once_with( + name=ANY, + role=ANY, + container_defs={ + "Image": ANY, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "OPTION_TENSOR_PARALLEL_DEGREE": "8", + }, + "ModelDataSource": { + "S3DataSource": { + "S3Uri": ANY, + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"AcceptEula": True}, + } + }, + }, + vpc_config=None, + enable_network_isolation=False, # should be set to false + tags=ANY, + ) + mock_endpoint_from_production_variants.assert_called_once_with( + name=ANY, + production_variants=ANY, + tags=ANY, + kms_key=ANY, + vpc_config=ANY, + enable_network_isolation=False, + role=ANY, + live_logging=False, # this should be set to false for IC + wait=True, + ) + mock_create_inference_component.assert_called_once() + + +def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are_expected( + sagemaker_session, +): + with ( + patch.object( + Session, + "wait_for_optimization_job", + return_value={"OptimizationJobName": "mock_optimization_job"}, + ), + patch.object(Session, "create_model", return_value="mock_model") as mock_create_model, + patch.object( + Session, "endpoint_from_production_variants", return_value="mock_endpoint_name" + ) as mock_endpoint_from_production_variants, + ): + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] + + sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() + + schema_builder = SchemaBuilder("test", "test") + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-1-8b-instruct", + schema_builder=schema_builder, + sagemaker_session=sagemaker_session, + 
role_arn=role_arn, + ) + + optimized_model = model_builder.optimize( + instance_type="ml.g5.xlarge", # set to small instance in case a network call is made + quantization_config={ + "OverrideEnvironment": { + "OPTION_QUANTIZE": "fp8", + }, + }, + accept_eula=True, + ) + + assert ( + sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + is not None + ) + + optimized_model.deploy() + + mock_create_model.assert_called_once_with( + name=ANY, + role=ANY, + container_defs={ + "Image": ANY, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "OPTION_QUANTIZE": "fp8", + }, + "ModelDataSource": { + "S3DataSource": { + "S3Uri": ANY, + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"AcceptEula": True}, + } + }, + }, + vpc_config=None, + enable_network_isolation=True, # unlike the sharding path above, this path keeps network isolation enabled + tags=ANY, + ) + mock_endpoint_from_production_variants.assert_called_once() diff --git a/tests/integ/sagemaker/serve/test_serve_js_happy.py b/tests/integ/sagemaker/serve/test_serve_js_happy.py index 7835c8ae3c..807a5ad691 100644 --- a/tests/integ/sagemaker/serve/test_serve_js_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_js_happy.py @@ -12,6 +12,9 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import io +import sys + import pytest from sagemaker.serve.builder.model_builder import ModelBuilder @@ -34,6 +37,14 @@ JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16" ROLE_NAME = "SageMakerRole" +SAMPLE_MMS_PROMPT = [ + "How cute your dog is!", + "Your dog is so cute.", + "The mitochondria is the powerhouse of the cell.", +] +SAMPLE_MMS_RESPONSE = {"embedding": []} +JS_MMS_MODEL_ID = "huggingface-sentencesimilarity-bge-m3" + @pytest.fixture def happy_model_builder(sagemaker_session): @@ -46,6 +57,30 @@ def happy_model_builder(sagemaker_session): ) + +@pytest.fixture +def meta_textgeneration_llama_2_7b_f_schema(): + prompt = "Hello, I'm a language model," + response = "Hello, I'm a language model, and I'm here to help you with your English."
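# (Editor's note, an assumption about intent rather than part of the diff:
# SchemaBuilder is given a single representative input/output pair such as the
# prompt/response above and derives the endpoint's request/response
# (de)serialization from those samples, so the fixtures here wire up no
# explicit serializers.)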
+ sample_input = {"inputs": prompt} + sample_output = [{"generated_text": response}] + + return SchemaBuilder( + sample_input=sample_input, + sample_output=sample_output, + ) + + +@pytest.fixture +def happy_mms_model_builder(sagemaker_session): + iam_client = sagemaker_session.boto_session.client("iam") + return ModelBuilder( + model=JS_MMS_MODEL_ID, + schema_builder=SchemaBuilder(SAMPLE_MMS_PROMPT, SAMPLE_MMS_RESPONSE), + role_arn=iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"], + sagemaker_session=sagemaker_session, + ) + + @pytest.mark.skipif( PYTHON_VERSION_IS_NOT_310, reason="The goal of these test are to test the serving components of our feature", @@ -75,3 +110,90 @@ def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type): ) if caught_ex: raise caught_ex + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these test are to test the serving components of our feature", +) +@pytest.mark.slow_test +def test_happy_mms_sagemaker_endpoint(happy_mms_model_builder, gpu_instance_type): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + model = happy_mms_model_builder.build() + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy(instance_type=gpu_instance_type, endpoint_logging=False) + logger.info("Endpoint successfully deployed.") + + updated_sample_input = happy_mms_model_builder.schema_builder.sample_input + + predictor.predict(updated_sample_input) + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=happy_mms_model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + raise caught_ex + + +@pytest.mark.skipif( + True, + reason="Only enable after metadata is fully deployed.", +) +def test_js_model_with_deployment_configs( + meta_textgeneration_llama_2_7b_f_schema, + sagemaker_session, +): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-2-13b", + schema_builder=meta_textgeneration_llama_2_7b_f_schema, + ) + configs = model_builder.list_deployment_configs() + + assert len(configs) > 0 + + captured_output = io.StringIO() + sys.stdout = captured_output + model_builder.display_benchmark_metrics() + sys.stdout = sys.__stdout__ + assert captured_output.getvalue() is not None + + model_builder.set_deployment_config( + configs[0]["ConfigName"], + "ml.g5.2xlarge", + ) + model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session) + assert model.config_name == configs[0]["ConfigName"] + assert model_builder.get_deployment_config() is not None + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy(accept_eula=True) + logger.info("Endpoint successfully deployed.") + + updated_sample_input = model_builder.schema_builder.sample_input + + predictor.predict(updated_sample_input) + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + raise caught_ex diff --git 
a/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py b/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py index e7ebd9c5bf..345d5e5af9 100644 --- a/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_mlflow_pytorch_flavor_happy.py @@ -19,6 +19,8 @@ import io import numpy as np +from sagemaker.lineage.artifact import Artifact +from sagemaker.lineage.association import Association from sagemaker.s3 import S3Uploader from sagemaker.serve.builder.model_builder import ModelBuilder, Mode from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator @@ -29,12 +31,16 @@ PYTORCH_SQUEEZENET_MLFLOW_RESOURCE_DIR, SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, # SERVE_LOCAL_CONTAINER_TIMEOUT, - PYTHON_VERSION_IS_NOT_310, + # PYTHON_VERSION_IS_NOT_310, ) from tests.integ.timeout import timeout from tests.integ.utils import cleanup_model_resources import logging +from sagemaker.serve.utils.lineage_constants import ( + MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, +) + logger = logging.getLogger(__name__) ROLE_NAME = "SageMakerRole" @@ -160,9 +166,9 @@ def model_builder(request): # ), f"{caught_ex} was thrown when running pytorch squeezenet local container test" -@pytest.mark.skipif( - PYTHON_VERSION_IS_NOT_310, # or NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE, - reason="The goal of these test are to test the serving components of our feature", +@pytest.mark.skip( + reason="Testing against Python version 310 which is not supported anymore" + " https://github.com/aws/deep-learning-containers/blob/master/available_images.md", ) def test_happy_pytorch_sagemaker_endpoint_with_torch_serve( sagemaker_session, @@ -205,6 +211,19 @@ def test_happy_pytorch_sagemaker_endpoint_with_torch_serve( predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1) logger.info("Endpoint successfully deployed.") predictor.predict(test_image) + model_data_artifact = None + for artifact in Artifact.list( + source_uri=model_builder.s3_upload_path, sagemaker_session=sagemaker_session + ): + model_data_artifact = artifact + for association in Association.list( + destination_arn=model_data_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ): + assert ( + association.source_type == MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE + ) + break except Exception as e: caught_ex = e finally: @@ -214,9 +233,4 @@ def test_happy_pytorch_sagemaker_endpoint_with_torch_serve( endpoint_name=model.endpoint_name, ) if caught_ex: - logger.exception(caught_ex) - ignore_if_worker_dies = "Worker died." in str(caught_ex) - # https://github.com/pytorch/serve/issues/3032 - assert ( - ignore_if_worker_dies - ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test" + raise caught_ex diff --git a/tests/integ/sagemaker/serve/test_serve_mlflow_tensorflow_flavor_happy.py b/tests/integ/sagemaker/serve/test_serve_mlflow_tensorflow_flavor_happy.py new file mode 100644 index 0000000000..c25cbd7e18 --- /dev/null +++ b/tests/integ/sagemaker/serve/test_serve_mlflow_tensorflow_flavor_happy.py @@ -0,0 +1,175 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import io +import numpy as np + +from sagemaker.lineage.artifact import Artifact +from sagemaker.lineage.association import Association +from sagemaker.s3 import S3Uploader +from sagemaker.serve.builder.model_builder import ModelBuilder, Mode +from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator +import tensorflow as tf +from sklearn.datasets import fetch_california_housing + + +from tests.integ.sagemaker.serve.constants import ( + TENSORFLOW_MLFLOW_RESOURCE_DIR, + SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, + PYTHON_VERSION_IS_NOT_310, +) +from tests.integ.timeout import timeout +from tests.integ.utils import cleanup_model_resources +import logging + +from sagemaker.serve.utils.lineage_constants import ( + MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, +) + +logger = logging.getLogger(__name__) + +ROLE_NAME = "SageMakerRole" + + +@pytest.fixture +def test_data(): + dataset = fetch_california_housing(as_frame=True)["frame"] + dataset = dataset.dropna() + dataset_tf = tf.convert_to_tensor(dataset, dtype=tf.float32) + dataset_tf = dataset_tf[:50] + x_test, y_test = dataset_tf[:, :-1], dataset_tf[:, -1] + return x_test, y_test + + +@pytest.fixture +def custom_request_translator(): + class MyRequestTranslator(CustomPayloadTranslator): + def serialize_payload_to_bytes(self, payload: object) -> bytes: + return self._convert_numpy_to_bytes(payload) + + def deserialize_payload_from_stream(self, stream) -> object: + np_array = np.load(io.BytesIO(stream.read())) + return np_array + + def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes: + buffer = io.BytesIO() + np.save(buffer, np_array) + return buffer.getvalue() + + return MyRequestTranslator() + + +@pytest.fixture +def custom_response_translator(): + class MyResponseTranslator(CustomPayloadTranslator): + def serialize_payload_to_bytes(self, payload: object) -> bytes: + import numpy as np + + return self._convert_numpy_to_bytes(np.array(payload)) + + def deserialize_payload_from_stream(self, stream) -> object: + import tensorflow as tf + + return tf.convert_to_tensor(np.load(io.BytesIO(stream.read()))) + + def _convert_numpy_to_bytes(self, np_array: np.ndarray) -> bytes: + buffer = io.BytesIO() + np.save(buffer, np_array) + return buffer.getvalue() + + return MyResponseTranslator() + + +@pytest.fixture +def tensorflow_schema_builder(custom_request_translator, custom_response_translator, test_data): + input_data, output_data = test_data + return SchemaBuilder( + sample_input=input_data, + sample_output=output_data, + input_translator=custom_request_translator, + output_translator=custom_response_translator, + ) + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these test are to test the serving components of our feature", +) +def test_happy_tensorflow_sagemaker_endpoint_with_tensorflow_serving( + sagemaker_session, + tensorflow_schema_builder, + cpu_instance_type, + test_data, +): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"] + + model_artifacts_uri = "s3://{}/{}/{}/{}".format( + sagemaker_session.default_bucket(), + 
"model_builder_integ_test", + "mlflow", + "tensorflow", + ) + + model_path = S3Uploader.upload( + local_path=TENSORFLOW_MLFLOW_RESOURCE_DIR, + desired_s3_uri=model_artifacts_uri, + sagemaker_session=sagemaker_session, + ) + + model_builder = ModelBuilder( + mode=Mode.SAGEMAKER_ENDPOINT, + schema_builder=tensorflow_schema_builder, + role_arn=role_arn, + sagemaker_session=sagemaker_session, + model_metadata={"MLFLOW_MODEL_PATH": model_path}, + ) + + model = model_builder.build(sagemaker_session=sagemaker_session) + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + test_x, _ = test_data + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1) + logger.info("Endpoint successfully deployed.") + predictor.predict(test_x) + model_data_artifact = None + for artifact in Artifact.list( + source_uri=model_builder.s3_upload_path, sagemaker_session=sagemaker_session + ): + model_data_artifact = artifact + for association in Association.list( + destination_arn=model_data_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ): + assert ( + association.source_type == MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE + ) + break + + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + raise caught_ex diff --git a/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py b/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py index 5a73942afe..7b47440a97 100644 --- a/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_mlflow_xgboost_flavor_happy.py @@ -16,6 +16,8 @@ import io import numpy as np +from sagemaker.lineage.artifact import Artifact +from sagemaker.lineage.association import Association from sagemaker.s3 import S3Uploader from sagemaker.serve.builder.model_builder import ModelBuilder, Mode from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator @@ -32,6 +34,10 @@ from tests.integ.utils import cleanup_model_resources import logging +from sagemaker.serve.utils.lineage_constants import ( + MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, +) + logger = logging.getLogger(__name__) ROLE_NAME = "SageMakerRole" @@ -187,6 +193,19 @@ def test_happy_xgboost_sagemaker_endpoint_with_torch_serve( predictor = model.deploy(instance_type=cpu_instance_type, initial_instance_count=1) logger.info("Endpoint successfully deployed.") predictor.predict(test_x) + model_data_artifact = None + for artifact in Artifact.list( + source_uri=model_builder.s3_upload_path, sagemaker_session=sagemaker_session + ): + model_data_artifact = artifact + for association in Association.list( + destination_arn=model_data_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ): + assert ( + association.source_type == MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE + ) + break except Exception as e: caught_ex = e finally: @@ -196,9 +215,4 @@ def test_happy_xgboost_sagemaker_endpoint_with_torch_serve( endpoint_name=model.endpoint_name, ) if caught_ex: - logger.exception(caught_ex) - ignore_if_worker_dies = "Worker died." 
in str(caught_ex) - # https://github.com/pytorch/serve/issues/3032 - assert ( - ignore_if_worker_dies - ), f"{caught_ex} was thrown when running pytorch squeezenet sagemaker endpoint test" + raise caught_ex diff --git a/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py b/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py index 933c18bacf..cf1eb65325 100644 --- a/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py +++ b/tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py @@ -71,9 +71,12 @@ def model_input(): @pytest.fixture -def model_builder_model_schema_builder(): +def model_builder_model_schema_builder(sagemaker_session): return ModelBuilder( - model_path=HF_DIR, model=model_id, schema_builder=SchemaBuilder(sample_input, sample_output) + sagemaker_session=sagemaker_session, + model_path=HF_DIR, + model=model_id, + schema_builder=SchemaBuilder(sample_input, sample_output), ) @@ -93,6 +96,8 @@ def model_builder(request): def test_non_text_generation_model_single_GPU( sagemaker_session, model_builder, model_input, **kwargs ): + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") iam_client = sagemaker_session.boto_session.client("iam") role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session) @@ -144,6 +149,8 @@ def test_non_text_generation_model_single_GPU( def test_non_text_generation_model_multi_GPU( sagemaker_session, model_builder, model_input, **kwargs ): + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") iam_client = sagemaker_session.boto_session.client("iam") role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] caught_ex = None diff --git a/tests/integ/sagemaker/serve/test_serve_model_builder_handshake.py b/tests/integ/sagemaker/serve/test_serve_model_builder_handshake.py new file mode 100644 index 0000000000..d024e761a8 --- /dev/null +++ b/tests/integ/sagemaker/serve/test_serve_model_builder_handshake.py @@ -0,0 +1,208 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
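# --- Editor's sketch (hypothetical, not part of this change): the repeated
# --- ml.p2.xlarge guard added above could be factored into a small helper;
# --- all names here are illustrative.
import pytest

DEPRECATED_INSTANCE_TYPES = {"ml.p2.xlarge"}

def skip_if_deprecated(instance_type):
    # Skip the calling test when it is parametrized with a retired instance type.
    if instance_type in DEPRECATED_INSTANCE_TYPES:
        pytest.skip(f"Instance type {instance_type} has been deprecated")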
+from __future__ import absolute_import + +import pytest +import os +import uuid + +import numpy as np +import pandas as pd +from sagemaker_core.main.resources import TrainingJob +from xgboost import XGBClassifier + +from sagemaker.serve import ModelBuilder, SchemaBuilder +from sagemaker.serve.spec.inference_spec import InferenceSpec +from sagemaker_core.main.shapes import ( + OutputDataConfig, + StoppingCondition, + Channel, + DataSource, + S3DataSource, + AlgorithmSpecification, + ResourceConfig, +) +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split + +from sagemaker import get_execution_role, image_uris +from sagemaker.modules.train import ModelTrainer + +prefix = "DEMO-scikit-iris" +TRAIN_DATA = "train.csv" +TEST_DATA = "test.csv" +DATA_DIRECTORY = "data" + + +class XGBoostSpec(InferenceSpec): + def load(self, model_dir: str): + print(model_dir) + model = XGBClassifier() + model.load_model(model_dir + "/xgboost-model") + return model + + def invoke(self, input_object: object, model: object): + prediction_probabilities = model.predict_proba(input_object) + predictions = np.argmax(prediction_probabilities, axis=1) + return predictions + + +@pytest.fixture(scope="module") +def data_setup(mb_sagemaker_session): + sagemaker_session = mb_sagemaker_session + bucket = sagemaker_session.default_bucket() + + iris = load_iris() + iris_df = pd.DataFrame(iris.data, columns=iris.feature_names) + iris_df["target"] = iris.target + + os.makedirs("./data", exist_ok=True) + + iris_df = iris_df[["target"] + [col for col in iris_df.columns if col != "target"]] + + train_data, test_data = train_test_split(iris_df, test_size=0.2, random_state=42) + + train_data.to_csv("./data/train.csv", index=False, header=False) + test_data.to_csv("./data/test.csv", index=False, header=False) + + data = {"Name": ["Alice", "Bob", "Charlie"]} + df = pd.DataFrame(data) + schema_builder = SchemaBuilder(sample_input=df, sample_output=df) + + # Remove the target column from the testing data. 
We will use this to call invoke_endpoint later + test_data = test_data.drop("target", axis=1) + + sagemaker_session.upload_data( + DATA_DIRECTORY, bucket=bucket, key_prefix="{}/{}".format(prefix, DATA_DIRECTORY) + ) + + s3_input_path = "s3://{}/{}/data/{}".format(bucket, prefix, TRAIN_DATA) + s3_output_path = "s3://{}/{}/output".format(bucket, prefix) + + data_setup = { + "s3_input_path": s3_input_path, + "s3_output_path": s3_output_path, + "schema_builder": schema_builder, + } + return data_setup + + +def test_model_trainer_handshake(mb_sagemaker_session, mb_sagemaker_core_session, data_setup): + sagemaker_session = mb_sagemaker_session + role = get_execution_role(sagemaker_session=sagemaker_session) + xgboost_image = image_uris.retrieve( + framework="xgboost", region="us-west-2", image_scope="training" + ) + + model_trainer = ModelTrainer( + sagemaker_session=mb_sagemaker_core_session, + base_job_name="test-mb-handshake", + hyperparameters={ + "objective": "multi:softmax", + "num_class": "3", + "num_round": "10", + "eval_metric": "merror", + }, + training_image=xgboost_image, + training_input_mode="File", + role=role, + output_data_config=OutputDataConfig(s3_output_path=data_setup["s3_output_path"]), + stopping_condition=StoppingCondition(max_runtime_in_seconds=600), + ) + + model_trainer.train( + input_data_config=[ + Channel( + channel_name="train", + content_type="csv", + compression_type="None", + record_wrapper_type="None", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=data_setup["s3_input_path"], + s3_data_distribution_type="FullyReplicated", + ) + ), + ) + ] + ) + + model_builder = ModelBuilder( + model=model_trainer, # ModelTrainer object passed onto ModelBuilder directly + sagemaker_session=sagemaker_session, + role_arn=role, + image_uri=xgboost_image, + inference_spec=XGBoostSpec(), + schema_builder=data_setup["schema_builder"], + instance_type="ml.c6i.xlarge", + ) + model = model_builder.build() + assert model.model_data == model_trainer._latest_training_job.model_artifacts.s3_model_artifacts + + +def test_sagemaker_core_handshake(mb_sagemaker_session, data_setup): + sagemaker_session = mb_sagemaker_session + role = get_execution_role(sagemaker_session=sagemaker_session) + xgboost_image = image_uris.retrieve( + framework="xgboost", region="us-west-2", image_scope="training" + ) + + training_job_name = str(uuid.uuid4()) + training_job = TrainingJob.create( + training_job_name=training_job_name, + hyper_parameters={ + "objective": "multi:softmax", + "num_class": "3", + "num_round": "10", + "eval_metric": "merror", + }, + algorithm_specification=AlgorithmSpecification( + training_image=xgboost_image, training_input_mode="File" + ), + role_arn=role, + input_data_config=[ + Channel( + channel_name="train", + content_type="csv", + compression_type="None", + record_wrapper_type="None", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=data_setup["s3_input_path"], + s3_data_distribution_type="FullyReplicated", + ) + ), + ) + ], + output_data_config=OutputDataConfig(s3_output_path=data_setup["s3_output_path"]), + resource_config=ResourceConfig( + instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=30 + ), + stopping_condition=StoppingCondition(max_runtime_in_seconds=600), + ) + training_job.wait() + + model_builder = ModelBuilder( + sagemaker_session=sagemaker_session, + model=training_job, + role_arn=role, + inference_spec=XGBoostSpec(), + image_uri=xgboost_image, +
schema_builder=data_setup["schema_builder"], + instance_type="ml.c6i.xlarge", + ) + model = model_builder.build() + + assert model.model_data == training_job.model_artifacts.s3_model_artifacts diff --git a/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py b/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py new file mode 100644 index 0000000000..7191de4e7d --- /dev/null +++ b/tests/integ/sagemaker/serve/test_serve_model_builder_inference_component_happy.py @@ -0,0 +1,150 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import tests.integ +import uuid + +from botocore.exceptions import ClientError +from sagemaker.predictor import Predictor +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serve.builder.schema_builder import SchemaBuilder +from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.utils import unique_name_from_base + +from tests.integ.sagemaker.serve.constants import ( + SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, +) +from tests.integ.timeout import timeout +import logging + +logger = logging.getLogger(__name__) + +sample_input = {"inputs": "What are falcons?", "parameters": {"max_new_tokens": 32}} + +sample_output = [ + { + "generated_text": "Falcons are small to medium-sized birds of prey related to hawks and eagles." 
+ } +] + +LLAMA_2_7B_JS_ID = "meta-textgeneration-llama-2-7b" +LLAMA_IC_NAME = "llama2-mb-ic" +INSTANCE_TYPE = "ml.g5.24xlarge" + + +@pytest.fixture +def model_builder_llama_inference_component(): + return ModelBuilder( + model=LLAMA_2_7B_JS_ID, + schema_builder=SchemaBuilder(sample_input, sample_output), + resource_requirements=ResourceRequirements( + requests={"memory": 98304, "num_accelerators": 4, "copies": 1, "num_cpus": 40} + ), + ) + + +@pytest.mark.skipif( + tests.integ.test_region() != "us-west-2", + reason="G5 capacity is only available in us-west-2 (PDX).", +) +def test_model_builder_ic_sagemaker_endpoint( + sagemaker_session, + model_builder_llama_inference_component, +): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + + model_builder_llama_inference_component.sagemaker_session = sagemaker_session + model_builder_llama_inference_component.instance_type = INSTANCE_TYPE + + model_builder_llama_inference_component.inference_component_name = unique_name_from_base( + LLAMA_IC_NAME + ) + + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + chain = ModelBuilder( + modelbuilder_list=[ + model_builder_llama_inference_component, + ], + role_arn=role_arn, + sagemaker_session=sagemaker_session, + ) + + chain.build() + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + endpoint_name = f"llama-ic-endpoint-name-{uuid.uuid1().hex}" + predictors = chain.deploy( + instance_type=INSTANCE_TYPE, + initial_instance_count=1, + accept_eula=True, + endpoint_name=endpoint_name, + ) + logger.info("Inference components successfully deployed.") + predictors[0].predict(sample_input) + assert len(predictors) == 1 + except Exception as e: + caught_ex = e + finally: + if caught_ex: + logger.exception(caught_ex) + cleanup_resources(sagemaker_session, [LLAMA_IC_NAME]) + assert False, f"{caught_ex} thrown when running mb-IC deployment test." + + cleanup_resources(sagemaker_session, [LLAMA_IC_NAME]) + + +def cleanup_resources(sagemaker_session, ic_base_names): + sm_client = sagemaker_session.sagemaker_client + + endpoint_names = set() + for ic_base_name in ic_base_names: + response = sm_client.list_inference_components( + NameContains=ic_base_name, StatusEquals="InService" + ) + ics = response["InferenceComponents"] + + logger.info(f"Cleaning up {len(ics)} ICs with base name {ic_base_name}.") + for ic in ics: + ic_name = ic["InferenceComponentName"] + ep_name = ic["EndpointName"] + + try: + logger.info(f"Deleting IC with name {ic_name}") + Predictor( + endpoint_name=ep_name, + component_name=ic_name, + sagemaker_session=sagemaker_session, + ).delete_predictor() + sagemaker_session.wait_for_inference_component_deletion( + inference_component_name=ic_name, + poll=10, + ) + endpoint_names.add(ep_name) + except ClientError as e: + logger.warning(e) + + for endpoint_name in endpoint_names: + logger.info(f"Deleting endpoint with name {endpoint_name}") + try: + Predictor( + endpoint_name=endpoint_name, sagemaker_session=sagemaker_session + ).delete_endpoint() + except ClientError as e: + logger.warning(e) diff --git a/tests/integ/sagemaker/serve/test_serve_tei.py b/tests/integ/sagemaker/serve/test_serve_tei.py new file mode 100644 index 0000000000..4c824da401 --- /dev/null +++ b/tests/integ/sagemaker/serve/test_serve_tei.py @@ -0,0 +1,92 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from sagemaker.serve.builder.schema_builder import SchemaBuilder +from sagemaker.serve.builder.model_builder import ModelBuilder, Mode + +from tests.integ.sagemaker.serve.constants import ( + HF_DIR, + PYTHON_VERSION_IS_NOT_310, + SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, +) + +from tests.integ.timeout import timeout +from tests.integ.utils import cleanup_model_resources +import logging + +logger = logging.getLogger(__name__) + +sample_input = {"inputs": "What is Deep Learning?"} + +loaded_response = [] + + +@pytest.fixture +def model_input(): + return {"inputs": "What is Deep Learning?"} + + +@pytest.fixture +def model_builder_model_schema_builder(sagemaker_session): + return ModelBuilder( + sagemaker_session=sagemaker_session, + model_path=HF_DIR, + model="BAAI/bge-m3", + schema_builder=SchemaBuilder(sample_input, loaded_response), + env_vars={ + # Add this to bypass JumpStart model mapping + "HF_MODEL_ID": "BAAI/bge-m3" + }, + ) + + +@pytest.fixture +def model_builder(request): + return request.getfixturevalue(request.param) + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="Testing feature needs latest metadata", +) +@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True) +def test_tei_sagemaker_endpoint(sagemaker_session, model_builder, model_input): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + model = model_builder.build( + mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session + ) + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy(instance_type="ml.g5.2xlarge", initial_instance_count=1) + predictor.predict(model_input) + assert predictor is not None + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert False, f"{caught_ex} was thrown when running tei sagemaker endpoint test" diff --git a/tests/integ/sagemaker/serve/test_serve_transformers.py b/tests/integ/sagemaker/serve/test_serve_transformers.py index 64029f7290..9405934474 100644 --- a/tests/integ/sagemaker/serve/test_serve_transformers.py +++ b/tests/integ/sagemaker/serve/test_serve_transformers.py @@ -72,11 +72,12 @@ def model_input(): @pytest.fixture -def model_builder_model_schema_builder(): +def model_builder_model_schema_builder(sagemaker_session): return ModelBuilder( model_path=HF_DIR, model="bert-base-uncased", schema_builder=SchemaBuilder(sample_input, loaded_response), + sagemaker_session=sagemaker_session, ) @@ -96,6 +97,9 @@ def model_builder(request): def test_pytorch_transformers_sagemaker_endpoint( sagemaker_session, 
model_builder, model_input, **kwargs ): + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") + logger.info("Running in SAGEMAKER_ENDPOINT mode...") caught_ex = None @@ -127,4 +131,4 @@ def test_pytorch_transformers_sagemaker_endpoint( logger.exception(caught_ex) assert ( False - ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test" + ), f"{caught_ex} thrown when running pytorch transformers sagemaker endpoint test" diff --git a/tests/integ/sagemaker/serve/utils/test_hardware_detector.py b/tests/integ/sagemaker/serve/utils/test_hardware_detector.py index 9102927c55..bab26a25d1 100644 --- a/tests/integ/sagemaker/serve/utils/test_hardware_detector.py +++ b/tests/integ/sagemaker/serve/utils/test_hardware_detector.py @@ -19,7 +19,7 @@ REGION = "us-west-2" VALID_INSTANCE_TYPE = "ml.g5.48xlarge" INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge" -EXPECTED_INSTANCE_GPU_INFO = (8, 196608) +EXPECTED_INSTANCE_GPU_INFO = (8, 183104) def test_get_gpu_info_success(sagemaker_session): diff --git a/tests/integ/sagemaker/workflow/helpers.py b/tests/integ/sagemaker/workflow/helpers.py index 20365ef169..9f0176c5c2 100644 --- a/tests/integ/sagemaker/workflow/helpers.py +++ b/tests/integ/sagemaker/workflow/helpers.py @@ -70,8 +70,8 @@ def create_and_execute_pipeline( assert execution_steps[0]["StepStatus"] == step_status if step_result_type: result = execution.result(execution_steps[0]["StepName"]) - assert ( - type(result) == step_result_type + assert isinstance( + result, step_result_type ), f"Expected {step_result_type}, instead found {type(result)}" if step_result_value: diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 0733649cb2..e84c1920f4 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -18,12 +18,17 @@ # and the RegisterModel and CreateModelStep have been replaced with the new interface - ModelStep from __future__ import absolute_import +import json import logging import os import re import pytest +from packaging.version import Version + +from sagemaker.model_card.model_card import ModelCard, ModelOverview, ModelPackageModelCard +from sagemaker.model_card.schema_constraints import ModelCardStatusEnum import tests from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution from sagemaker.tensorflow import TensorFlow, TensorFlowModel @@ -43,6 +48,7 @@ from sagemaker.s3 import S3Uploader from sagemaker.sklearn import SKLearnModel, SKLearnProcessor from sagemaker.mxnet.model import MXNetModel +from sagemaker.model_life_cycle import ModelLifeCycle from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.parameters import ParameterInteger, ParameterString from sagemaker.workflow.pipeline import Pipeline @@ -56,6 +62,15 @@ ) from tests.integ.kms_utils import get_or_create_kms_key from tests.integ import DATA_DIR +from sagemaker.model_card import ( + IntendedUses, + BusinessDetails, + EvaluationJob, + AdditionalInformation, + Metric, + MetricGroup, + MetricTypeEnum, +) @pytest.fixture @@ -703,6 +718,594 @@ def test_model_registration_with_drift_check_baselines( pass +def test_model_registration_with_model_card_object( + sagemaker_session_for_pipeline, + role, + pipeline_name, +): + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + instance_type = 
"ml.m5.xlarge" + + # upload model data to s3 + model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz") + model_base_uri = "s3://{}/{}/input/model/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("model"), + ) + model_uri = S3Uploader.upload( + model_local_path, model_base_uri, sagemaker_session=sagemaker_session_for_pipeline + ) + model_uri_param = ParameterString(name="model_uri", default_value=model_uri) + + # upload metrics to s3 + metrics_data = ( + '{"regression_metrics": {"mse": {"value": 4.925353410353891, ' + '"standard_deviation": 2.219186917819692}}}' + ) + metrics_base_uri = "s3://{}/{}/input/metrics/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("metrics"), + ) + metrics_uri = S3Uploader.upload_string_as_file_body( + body=metrics_data, + desired_s3_uri=metrics_base_uri, + sagemaker_session=sagemaker_session_for_pipeline, + ) + metrics_uri_param = ParameterString(name="metrics_uri", default_value=metrics_uri) + + model_metrics = ModelMetrics( + bias=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + explainability=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_pre_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_post_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + ) + customer_metadata_properties = {"key1": "value1"} + domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + skip_model_validation = "All" + + # If image_uri is not provided, the instance_type should not be a pipeline variable + # since instance_type is used to retrieve image_uri in compile time (PySDK) + estimator = XGBoost( + entry_point="training.py", + source_dir=os.path.join(DATA_DIR, "sip"), + instance_type=instance_type, + instance_count=instance_count, + framework_version="0.90-2", + sagemaker_session=sagemaker_session_for_pipeline, + py_version="py3", + role=role, + ) + intended_uses = IntendedUses( + purpose_of_model="Test model card.", + intended_uses="Not used except this test.", + factors_affecting_model_efficiency="No.", + risk_rating="Low", + explanations_for_risk_rating="Just an example.", + ) + business_details = BusinessDetails( + business_problem="The business problem that your model is used to solve.", + business_stakeholders="The stakeholders who have the interest in the business that your model is used for.", + line_of_business="Services that the business is offering.", + ) + additional_information = AdditionalInformation( + ethical_considerations="Your model ethical consideration.", + caveats_and_recommendations="Your model's caveats and recommendations.", + custom_details={"custom details1": "details value"}, + ) + manual_metric_group = MetricGroup( + name="binary classification metrics", + metric_data=[Metric(name="accuracy", type=MetricTypeEnum.NUMBER, value=0.5)], + ) + example_evaluation_job = EvaluationJob( + name="Example evaluation job", + evaluation_observation="Evaluation observations.", + datasets=["s3://path/to/evaluation/data"], + metric_groups=[manual_metric_group], + ) + evaluation_details = [example_evaluation_job] + + 
model_overview = ModelOverview(model_creator="TestCreator") + + my_card = ModelCard( + name="TestName", + sagemaker_session=sagemaker_session_for_pipeline, + status=ModelCardStatusEnum.DRAFT, + model_overview=model_overview, + intended_uses=intended_uses, + business_details=business_details, + evaluation_details=evaluation_details, + additional_information=additional_information, + ) + + step_register = RegisterModel( + name="MyRegisterModelStep", + estimator=estimator, + model_data=model_uri_param, + content_types=["application/json"], + response_types=["application/json"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_group_name="testModelPackageGroup", + model_metrics=model_metrics, + customer_metadata_properties=customer_metadata_properties, + domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, + model_card=my_card, + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[ + model_uri_param, + metrics_uri_param, + instance_count, + ], + steps=[step_register], + sagemaker_session=sagemaker_session_for_pipeline, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + + for _ in retries( + max_retry_count=5, + exception_message_prefix="Waiting for a successful execution of pipeline", + seconds_to_sleep=10, + ): + execution = pipeline.start( + parameters={"model_uri": model_uri, "metrics_uri": metrics_uri} + ) + response = execution.describe() + + assert response["PipelineArn"] == create_arn + + wait_pipeline_execution(execution=execution) + execution_steps = execution.list_steps() + + assert len(execution_steps) == 1 + failure_reason = execution_steps[0].get("FailureReason", "") + if failure_reason != "": + logging.error( + f"Pipeline execution failed with error: {failure_reason}." " Retrying.." + ) + continue + assert execution_steps[0]["StepStatus"] == "Succeeded" + assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel" + + response = sagemaker_session_for_pipeline.sagemaker_client.describe_model_package( + ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"] + ) + + assert ( + response["ModelMetrics"]["Explainability"]["Report"]["ContentType"] + == "application/json" + ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties + assert response["Domain"] == domain + assert response["Task"] == task + assert response["SamplePayloadUrl"] == sample_payload_url + assert response["SkipModelValidation"] == skip_model_validation + assert (response["ModelCard"]["ModelCardStatus"]) == ModelCardStatusEnum.DRAFT + model_card_content = json.loads(response["ModelCard"]["ModelCardContent"]) + assert (model_card_content["model_overview"]["model_creator"]) == "TestCreator" + assert (model_card_content["intended_uses"]["purpose_of_model"]) == "Test model card." + assert ( + model_card_content["business_details"]["line_of_business"] + ) == "Services that the business is offering." 
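+ # evaluation_details is stored as a JSON list in ModelCardContent, so index 0 + # is the single evaluation job defined above.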
+ assert (model_card_content["evaluation_details"][0]["name"]) == "Example evaluation job" + + break + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_model_registration_with_model_life_cycle_object( + sagemaker_session_for_pipeline, + role, + pipeline_name, +): + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + instance_type = "ml.m5.xlarge" + + # upload model data to s3 + model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz") + model_base_uri = "s3://{}/{}/input/model/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("model"), + ) + model_uri = S3Uploader.upload( + model_local_path, model_base_uri, sagemaker_session=sagemaker_session_for_pipeline + ) + model_uri_param = ParameterString(name="model_uri", default_value=model_uri) + + # upload metrics to s3 + metrics_data = ( + '{"regression_metrics": {"mse": {"value": 4.925353410353891, ' + '"standard_deviation": 2.219186917819692}}}' + ) + metrics_base_uri = "s3://{}/{}/input/metrics/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("metrics"), + ) + metrics_uri = S3Uploader.upload_string_as_file_body( + body=metrics_data, + desired_s3_uri=metrics_base_uri, + sagemaker_session=sagemaker_session_for_pipeline, + ) + metrics_uri_param = ParameterString(name="metrics_uri", default_value=metrics_uri) + + model_metrics = ModelMetrics( + bias=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + explainability=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_pre_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_post_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + ) + customer_metadata_properties = {"key1": "value1"} + domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + skip_model_validation = "All" + + # If image_uri is not provided, the instance_type should not be a pipeline variable + # since instance_type is used to retrieve image_uri in compile time (PySDK) + estimator = XGBoost( + entry_point="training.py", + source_dir=os.path.join(DATA_DIR, "sip"), + instance_type=instance_type, + instance_count=instance_count, + framework_version="0.90-2", + sagemaker_session=sagemaker_session_for_pipeline, + py_version="py3", + role=role, + ) + create_model_life_cycle = ModelLifeCycle( + stage="Development", + stage_status="In-Progress", + stage_description="Development In Progress", + ) + + step_register = RegisterModel( + name="MyRegisterModelStep", + estimator=estimator, + model_data=model_uri_param, + content_types=["application/json"], + response_types=["application/json"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_group_name="testModelPackageGroup", + model_metrics=model_metrics, + customer_metadata_properties=customer_metadata_properties, + domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + 
skip_model_validation=skip_model_validation, + model_life_cycle=create_model_life_cycle, + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[ + model_uri_param, + metrics_uri_param, + instance_count, + ], + steps=[step_register], + sagemaker_session=sagemaker_session_for_pipeline, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + + for _ in retries( + max_retry_count=5, + exception_message_prefix="Waiting for a successful execution of pipeline", + seconds_to_sleep=10, + ): + execution = pipeline.start( + parameters={"model_uri": model_uri, "metrics_uri": metrics_uri} + ) + response = execution.describe() + + assert response["PipelineArn"] == create_arn + + wait_pipeline_execution(execution=execution) + execution_steps = execution.list_steps() + + assert len(execution_steps) == 1 + failure_reason = execution_steps[0].get("FailureReason", "") + if failure_reason != "": + logging.error( + f"Pipeline execution failed with error: {failure_reason}." " Retrying.." + ) + continue + assert execution_steps[0]["StepStatus"] == "Succeeded" + assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel" + + response = sagemaker_session_for_pipeline.sagemaker_client.describe_model_package( + ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"] + ) + + assert ( + response["ModelMetrics"]["Explainability"]["Report"]["ContentType"] + == "application/json" + ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties + assert response["Domain"] == domain + assert response["Task"] == task + assert response["SamplePayloadUrl"] == sample_payload_url + assert response["SkipModelValidation"] == skip_model_validation + assert (response["ModelLifeCycle"]["Stage"]) == "Development" + assert (response["ModelLifeCycle"]["StageStatus"]) == "In-Progress" + assert (response["ModelLifeCycle"]["StageDescription"]) == "Development In Progress" + break + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_model_registration_with_model_card_json( + sagemaker_session_for_pipeline, + role, + pipeline_name, +): + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + instance_type = "ml.m5.xlarge" + + # upload model data to s3 + model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz") + model_base_uri = "s3://{}/{}/input/model/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("model"), + ) + model_uri = S3Uploader.upload( + model_local_path, model_base_uri, sagemaker_session=sagemaker_session_for_pipeline + ) + model_uri_param = ParameterString(name="model_uri", default_value=model_uri) + + # upload metrics to s3 + metrics_data = ( + '{"regression_metrics": {"mse": {"value": 4.925353410353891, ' + '"standard_deviation": 2.219186917819692}}}' + ) + metrics_base_uri = "s3://{}/{}/input/metrics/{}".format( + sagemaker_session_for_pipeline.default_bucket(), + "register_model_test_with_drift_baseline", + utils.unique_name_from_base("metrics"), + ) + metrics_uri = S3Uploader.upload_string_as_file_body( + body=metrics_data, + desired_s3_uri=metrics_base_uri, + sagemaker_session=sagemaker_session_for_pipeline, + ) + metrics_uri_param = ParameterString(name="metrics_uri", default_value=metrics_uri) + + model_metrics = ModelMetrics( + bias=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + explainability=MetricsSource( + s3_uri=metrics_uri_param, + 
content_type="application/json", + ), + bias_pre_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + bias_post_training=MetricsSource( + s3_uri=metrics_uri_param, + content_type="application/json", + ), + ) + customer_metadata_properties = {"key1": "value1"} + domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + skip_model_validation = "All" + + # If image_uri is not provided, the instance_type should not be a pipeline variable + # since instance_type is used to retrieve image_uri in compile time (PySDK) + estimator = XGBoost( + entry_point="training.py", + source_dir=os.path.join(DATA_DIR, "sip"), + instance_type=instance_type, + instance_count=instance_count, + framework_version="0.90-2", + sagemaker_session=sagemaker_session_for_pipeline, + py_version="py3", + role=role, + ) + + model_card_content = { + "model_overview": { + "model_creator": "TestCreator", + }, + "intended_uses": { + "purpose_of_model": "Test model card.", + "intended_uses": "Not used except this test.", + "factors_affecting_model_efficiency": "No.", + "risk_rating": "Low", + "explanations_for_risk_rating": "Just an example.", + }, + "business_details": { + "business_problem": "The business problem that your model is used to solve.", + "business_stakeholders": "The stakeholders who have the interest in the business.", + "line_of_business": "Services that the business is offering.", + }, + "evaluation_details": [ + { + "name": "Example evaluation job", + "evaluation_observation": "Evaluation observations.", + "metric_groups": [ + { + "name": "binary classification metrics", + "metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}], + } + ], + } + ], + "additional_information": { + "ethical_considerations": "Your model ethical consideration.", + "caveats_and_recommendations": "Your model's caveats and recommendations.", + "custom_details": {"custom details1": "details value"}, + }, + } + my_card = ModelPackageModelCard( + model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=model_card_content + ) + + step_register = RegisterModel( + name="MyRegisterModelStep", + estimator=estimator, + model_data=model_uri_param, + content_types=["application/json"], + response_types=["application/json"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_group_name="testModelPackageGroup", + model_metrics=model_metrics, + customer_metadata_properties=customer_metadata_properties, + domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, + model_card=my_card, + ) + + pipeline = Pipeline( + name=pipeline_name, + parameters=[ + model_uri_param, + metrics_uri_param, + instance_count, + ], + steps=[step_register], + sagemaker_session=sagemaker_session_for_pipeline, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + + for _ in retries( + max_retry_count=5, + exception_message_prefix="Waiting for a successful execution of pipeline", + seconds_to_sleep=10, + ): + execution = pipeline.start( + parameters={"model_uri": model_uri, "metrics_uri": metrics_uri} + ) + response =
execution.describe() + + assert response["PipelineArn"] == create_arn + + wait_pipeline_execution(execution=execution) + execution_steps = execution.list_steps() + + assert len(execution_steps) == 1 + failure_reason = execution_steps[0].get("FailureReason", "") + if failure_reason != "": + logging.error( + f"Pipeline execution failed with error: {failure_reason}." " Retrying.." + ) + continue + assert execution_steps[0]["StepStatus"] == "Succeeded" + assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel" + + response = sagemaker_session_for_pipeline.sagemaker_client.describe_model_package( + ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"] + ) + + assert ( + response["ModelMetrics"]["Explainability"]["Report"]["ContentType"] + == "application/json" + ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties + assert response["Domain"] == domain + assert response["Task"] == task + assert response["SamplePayloadUrl"] == sample_payload_url + assert response["SkipModelValidation"] == skip_model_validation + assert (response["ModelCard"]["ModelCardStatus"]) == ModelCardStatusEnum.DRAFT + model_card_content = json.loads(response["ModelCard"]["ModelCardContent"]) + assert (model_card_content["model_overview"]["model_creator"]) == "TestCreator" + assert (model_card_content["intended_uses"]["purpose_of_model"]) == "Test model card." + assert ( + model_card_content["business_details"]["line_of_business"] + ) == "Services that the business is offering." + assert (model_card_content["evaluation_details"][0]["name"]) == "Example evaluation job" + + break + finally: + try: + pipeline.delete() + except Exception: + pass + + def test_model_registration_with_model_repack( sagemaker_session_for_pipeline, role, @@ -819,6 +1422,11 @@ def test_model_registration_with_tensorflow_model_with_pipeline_model( pipeline_name, region_name, ): + if Version(tf_full_version) >= Version("2.16"): + pytest.skip( + "This test is failing in TensorFlow 2.16 because of an upstream bug: " + "https://github.com/tensorflow/io/issues/2039" + ) base_dir = os.path.join(DATA_DIR, "tensorflow_mnist") entry_point = os.path.join(base_dir, "mnist_v2.py") input_path = sagemaker_session_for_pipeline.upload_data( diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py index da63bca597..02f7613f85 100644 --- a/tests/integ/sagemaker/workflow/test_model_steps.py +++ b/tests/integ/sagemaker/workflow/test_model_steps.py @@ -17,6 +17,8 @@ import pytest + +from packaging.version import Version + from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution from sagemaker.workflow.fail_step import FailStep from sagemaker.workflow.functions import Join @@ -589,6 +591,11 @@ def test_model_registration_with_drift_check_baselines_and_model_metrics( def test_model_registration_with_tensorflow_model_with_pipeline_model( pipeline_session, role, tf_full_version, tf_full_py_version, pipeline_name ): + if Version(tf_full_version) >= Version("2.16"): + pytest.skip( + "This test is failing in TensorFlow 2.16 because of an upstream bug: " + "https://github.com/tensorflow/io/issues/2039" + ) base_dir = os.path.join(DATA_DIR, "tensorflow_mnist") entry_point = os.path.join(base_dir, "mnist_v2.py") input_path = pipeline_session.upload_data( diff --git a/tests/integ/sagemaker/workflow/test_retry.py b/tests/integ/sagemaker/workflow/test_retry.py index 31c9859d50..9960dceed4 100644 ---
a/tests/integ/sagemaker/workflow/test_retry.py +++ b/tests/integ/sagemaker/workflow/test_retry.py @@ -148,6 +148,8 @@ def test_model_registration_with_model_repack( role, pipeline_name, region_name, + pytorch_training_latest_version, + pytorch_training_latest_py_version, ): base_dir = os.path.join(DATA_DIR, "pytorch_mnist") entry_point = os.path.join(base_dir, "mnist.py") @@ -166,8 +168,8 @@ def test_model_registration_with_model_repack( pytorch_estimator = PyTorch( entry_point=entry_point, role=role, - framework_version="1.5.0", - py_version="py3", + framework_version=pytorch_training_latest_version, + py_version=pytorch_training_latest_py_version, instance_count=instance_count, instance_type=instance_type, sagemaker_session=pipeline_session, diff --git a/tests/integ/sagemaker/workflow/test_training_steps.py b/tests/integ/sagemaker/workflow/test_training_steps.py index 181167ab31..4b442c6d93 100644 --- a/tests/integ/sagemaker/workflow/test_training_steps.py +++ b/tests/integ/sagemaker/workflow/test_training_steps.py @@ -18,6 +18,8 @@ import pytest + +from packaging.version import Version + from tests.integ.sagemaker.workflow.helpers import wait_pipeline_execution from sagemaker import TrainingInput, get_execution_role, utils, image_uris from sagemaker.debugger import ( @@ -235,6 +237,12 @@ def test_training_step_with_output_path_as_join( def test_tensorflow_training_step_with_parameterized_code_input( pipeline_session, role, tf_full_version, tf_full_py_version, pipeline_name ): + if Version(tf_full_version) >= Version("2.16"): + pytest.skip( + "This test is failing in TensorFlow 2.16 because of an upstream bug: " + "https://github.com/tensorflow/io/issues/2039" + ) + base_dir = os.path.join(DATA_DIR, "tensorflow_mnist") entry_point1 = "mnist_v2.py" entry_point2 = "mnist_dummy.py" diff --git a/tests/integ/sagemaker/workflow/test_workflow.py b/tests/integ/sagemaker/workflow/test_workflow.py index 2643a3b88e..a879ff88e5 100644 --- a/tests/integ/sagemaker/workflow/test_workflow.py +++ b/tests/integ/sagemaker/workflow/test_workflow.py @@ -312,6 +312,7 @@ def test_three_step_definition( rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn, ) + assert pipeline.latest_pipeline_version_id == 1 finally: try: pipeline.delete() @@ -937,7 +938,6 @@ def test_large_pipeline(sagemaker_session_for_pipeline, role, pipeline_name, reg rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn, ) - response = pipeline.describe() assert len(json.loads(pipeline.describe()["PipelineDefinition"])["Steps"]) == 2000 pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)] @@ -1122,8 +1122,8 @@ def test_model_registration_with_tuning_model( entry_point=entry_point, source_dir=base_dir, role=role, - framework_version="1.10", - py_version="py38", + framework_version="1.13.1", + py_version="py39", instance_count=instance_count, instance_type=instance_type, sagemaker_session=pipeline_session, @@ -1159,8 +1159,8 @@ def test_model_registration_with_tuning_model( ), entry_point=entry_point, source_dir=base_dir, - framework_version="1.10", - py_version="py38", + framework_version="1.13.1", + py_version="py39", sagemaker_session=pipeline_session, ) step_model_regis_args = model.register( @@ -1387,3 +1387,56 @@ def test_caching_behavior( except Exception: os.remove(script_dir + "/dummy_script.py") pass + + +def test_pipeline_versioning(pipeline_session, role, pipeline_name, script_dir): + sklearn_train = SKLearn( + framework_version="0.20.0", +
entry_point=os.path.join(script_dir, "train.py"), + instance_type="ml.m5.xlarge", + sagemaker_session=pipeline_session, + role=role, + ) + + step1 = TrainingStep( + name="my-train-1", + display_name="TrainingStep", + description="description for Training step", + step_args=sklearn_train.fit(), + ) + + step2 = TrainingStep( + name="my-train-2", + display_name="TrainingStep", + description="description for Training step", + step_args=sklearn_train.fit(), + ) + pipeline = Pipeline( + name=pipeline_name, + steps=[step1], + sagemaker_session=pipeline_session, + ) + + try: + pipeline.create(role) + + assert pipeline.latest_pipeline_version_id == 1 + + describe_response = pipeline.describe(pipeline_version_id=1) + assert len(json.loads(describe_response["PipelineDefinition"])["Steps"]) == 1 + + pipeline.steps.append(step2) + pipeline.upsert(role) + + assert pipeline.latest_pipeline_version_id == 2 + + describe_response = pipeline.describe(pipeline_version_id=2) + assert len(json.loads(describe_response["PipelineDefinition"])["Steps"]) == 2 + + assert len(pipeline.list_pipeline_versions()["PipelineVersionSummaries"]) == 2 + + finally: + try: + pipeline.delete() + except Exception: + pass diff --git a/tests/integ/test_byo_estimator.py b/tests/integ/test_byo_estimator.py index a504b974a9..386db8e14b 100644 --- a/tests/integ/test_byo_estimator.py +++ b/tests/integ/test_byo_estimator.py @@ -12,14 +12,20 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import io import json import os +import numpy as np + import pytest +import sagemaker.amazon.common as smac + import sagemaker from sagemaker import image_uris from sagemaker.estimator import Estimator +from sagemaker.s3 import S3Uploader from sagemaker.serializers import SimpleBaseSerializer from sagemaker.utils import unique_name_from_base from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, datasets @@ -102,6 +108,60 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se assert prediction["score"] is not None +@pytest.mark.release +def test_estimator_register_publish_training_details(sagemaker_session, region, cpu_instance_type): + + bucket = sagemaker_session.default_bucket() + prefix = "model-card-sample-notebook" + + raw_data = ( + (0.5, 0), + (0.75, 0), + (1.0, 0), + (1.25, 0), + (1.50, 0), + (1.75, 0), + (2.0, 0), + (2.25, 1), + (2.5, 0), + (2.75, 1), + (3.0, 0), + (3.25, 1), + (3.5, 0), + (4.0, 1), + (4.25, 1), + (4.5, 1), + (4.75, 1), + (5.0, 1), + (5.5, 1), + ) + training_data = np.array(raw_data).astype("float32") + labels = training_data[:, 1] + + # upload data to S3 bucket + buf = io.BytesIO() + smac.write_numpy_to_dense_tensor(buf, training_data, labels) + buf.seek(0) + s3_train_data = f"s3://{bucket}/{prefix}/train" + S3Uploader.upload_bytes(b=buf, s3_uri=s3_train_data, sagemaker_session=sagemaker_session) + output_location = f"s3://{bucket}/{prefix}/output" + container = image_uris.retrieve("linear-learner", region) + estimator = Estimator( + container, + role="SageMakerRole", + instance_count=1, + instance_type=cpu_instance_type, + output_path=output_location, + sagemaker_session=sagemaker_session, + ) + estimator.set_hyperparameters( + feature_dim=2, mini_batch_size=10, predictor_type="binary_classifier" + ) + estimator.fit({"train": s3_train_data}) + print(f"Training job name: {estimator.latest_training_job.name}") + estimator.register() + + def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, training_set): 
image_uri = image_uris.retrieve("factorization-machines", region) endpoint_name = unique_name_from_base("byo") diff --git a/tests/integ/test_collection.py b/tests/integ/test_collection.py index 2ee1d90e34..9a6db645cf 100644 --- a/tests/integ/test_collection.py +++ b/tests/integ/test_collection.py @@ -19,20 +19,22 @@ def test_create_collection_root_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") - collection.create(collection_name) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200 - delete_response = collection.delete([collection_name]) - assert len(delete_response["deleted_collections"]) == 1 - assert len(delete_response["delete_collection_failures"]) == 0 + try: + collection.create(collection_name) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200 + finally: + delete_response = collection.delete([collection_name]) + assert len(delete_response["deleted_collections"]) == 1 + assert len(delete_response["delete_collection_failures"]) == 0 def test_create_collection_nested_success(sagemaker_session): @@ -41,25 +43,27 @@ def test_create_collection_nested_success(sagemaker_session): child_collection_name = unique_name_from_base("test-collection-2") collection.create(collection_name) collection.create(collection_name=child_collection_name, parent_collection_name=collection_name) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - # has one child i.e child collection - assert len(collection_details["Resources"]) == 1 - - collection_details = sagemaker_session.list_group_resources( - group=child_collection_name, filters=collection_filter - ) - collection_details["ResponseMetadata"]["HTTPStatusCode"] - delete_response = collection.delete([child_collection_name, collection_name]) - assert len(delete_response["deleted_collections"]) == 2 - assert len(delete_response["delete_collection_failures"]) == 0 + try: + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + # has one child, i.e. the child collection + assert len(collection_details["Resources"]) == 1 + + collection_details = sagemaker_session.list_group_resources( + group=child_collection_name, filters=collection_filter + ) + assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200 + finally: + delete_response = collection.delete([child_collection_name, collection_name]) + assert len(delete_response["deleted_collections"]) == 2 + assert len(delete_response["delete_collection_failures"]) == 0 def test_add_remove_model_groups_in_collection_success(sagemaker_session): @@ -70,40 +74,42 @@ def
test_add_remove_model_groups_in_collection_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") collection.create(collection_name) - model_groups = [] - model_groups.append(model_group_name) - add_response = collection.add_model_groups( - collection_name=collection_name, model_groups=model_groups - ) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - - assert len(add_response["failure"]) == 0 - assert len(add_response["added_groups"]) == 1 - assert len(collection_details["Resources"]) == 1 - - remove_response = collection.remove_model_groups( - collection_name=collection_name, model_groups=model_groups - ) - collection_details = sagemaker_session.list_group_resources( - group=collection_name, filters=collection_filter - ) - assert len(remove_response["failure"]) == 0 - assert len(remove_response["removed_groups"]) == 1 - assert len(collection_details["Resources"]) == 0 - - delete_response = collection.delete([collection_name]) - assert len(delete_response["deleted_collections"]) == 1 - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + add_response = collection.add_model_groups( + collection_name=collection_name, model_groups=model_groups + ) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + + assert len(add_response["failure"]) == 0 + assert len(add_response["added_groups"]) == 1 + assert len(collection_details["Resources"]) == 1 + + remove_response = collection.remove_model_groups( + collection_name=collection_name, model_groups=model_groups + ) + collection_details = sagemaker_session.list_group_resources( + group=collection_name, filters=collection_filter + ) + assert len(remove_response["failure"]) == 0 + assert len(remove_response["removed_groups"]) == 1 + assert len(collection_details["Resources"]) == 0 + + finally: + delete_response = collection.delete([collection_name]) + assert len(delete_response["deleted_collections"]) == 1 + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) def test_move_model_groups_in_collection_success(sagemaker_session): @@ -116,56 +122,58 @@ def test_move_model_groups_in_collection_success(sagemaker_session): destination_collection_name = unique_name_from_base("test-collection-destination") collection.create(source_collection_name) collection.create(destination_collection_name) - model_groups = [] - model_groups.append(model_group_name) - add_response = collection.add_model_groups( - collection_name=source_collection_name, model_groups=model_groups - ) - collection_filter = [ - { - "Name": "resource-type", - "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], - }, - ] - collection_details = sagemaker_session.list_group_resources( - group=source_collection_name, filters=collection_filter - ) - - assert len(add_response["failure"]) == 0 - assert len(add_response["added_groups"]) == 1 - assert len(collection_details["Resources"]) == 1 - - move_response = 
collection.move_model_group( - source_collection_name=source_collection_name, - model_group=model_group_name, - destination_collection_name=destination_collection_name, - ) - - assert move_response["moved_success"] == model_group_name - - collection_details = sagemaker_session.list_group_resources( - group=destination_collection_name, filters=collection_filter - ) - - assert len(collection_details["Resources"]) == 1 - - collection_details = sagemaker_session.list_group_resources( - group=source_collection_name, filters=collection_filter - ) - assert len(collection_details["Resources"]) == 0 - - remove_response = collection.remove_model_groups( - collection_name=destination_collection_name, model_groups=model_groups - ) - - assert len(remove_response["failure"]) == 0 - assert len(remove_response["removed_groups"]) == 1 - - delete_response = collection.delete([source_collection_name, destination_collection_name]) - assert len(delete_response["deleted_collections"]) == 2 - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_group_name - ) + try: + model_groups = [] + model_groups.append(model_group_name) + add_response = collection.add_model_groups( + collection_name=source_collection_name, model_groups=model_groups + ) + collection_filter = [ + { + "Name": "resource-type", + "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"], + }, + ] + collection_details = sagemaker_session.list_group_resources( + group=source_collection_name, filters=collection_filter + ) + + assert len(add_response["failure"]) == 0 + assert len(add_response["added_groups"]) == 1 + assert len(collection_details["Resources"]) == 1 + + move_response = collection.move_model_group( + source_collection_name=source_collection_name, + model_group=model_group_name, + destination_collection_name=destination_collection_name, + ) + + assert move_response["moved_success"] == model_group_name + + collection_details = sagemaker_session.list_group_resources( + group=destination_collection_name, filters=collection_filter + ) + + assert len(collection_details["Resources"]) == 1 + + collection_details = sagemaker_session.list_group_resources( + group=source_collection_name, filters=collection_filter + ) + assert len(collection_details["Resources"]) == 0 + + remove_response = collection.remove_model_groups( + collection_name=destination_collection_name, model_groups=model_groups + ) + + assert len(remove_response["failure"]) == 0 + assert len(remove_response["removed_groups"]) == 1 + + finally: + delete_response = collection.delete([source_collection_name, destination_collection_name]) + assert len(delete_response["deleted_collections"]) == 2 + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) def test_list_collection_success(sagemaker_session): @@ -176,23 +184,27 @@ def test_list_collection_success(sagemaker_session): collection = Collection(sagemaker_session) collection_name = unique_name_from_base("test-collection") collection.create(collection_name) - model_groups = [] - model_groups.append(model_group_name) - collection.add_model_groups(collection_name=collection_name, model_groups=model_groups) - child_collection_name = unique_name_from_base("test-collection") - collection.create(parent_collection_name=collection_name, collection_name=child_collection_name) - root_collections = collection.list_collection() - is_collection_found = False - for root_collection in root_collections: - if root_collection["Name"] == 
collection_name: + is_collection_found = True + assert is_collection_found + + collection_content = collection.list_collection(collection_name) + assert len(collection_content) == 2 + + collection.remove_model_groups(collection_name=collection_name, model_groups=model_groups) + finally: + collection.delete([child_collection_name, collection_name]) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_group_name + ) diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index 43db78527a..75f1807148 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -1645,9 +1645,11 @@ def test_create_dataset_with_feature_group_base( feature_store_session, feature_group, offline_store_s3_uri ) - with timeout(minutes=10) and cleanup_offline_store( - base, feature_store_session - ) and cleanup_offline_store(feature_group, feature_store_session): + with ( + timeout(minutes=10), + cleanup_offline_store(base, feature_store_session), + cleanup_offline_store(feature_group, feature_store_session), + ): feature_store = FeatureStore(sagemaker_session=feature_store_session) df, query_string = ( feature_store.create_dataset(base=base, output_path=offline_store_s3_uri) @@ -1832,9 +1834,11 @@ def test_create_dataset_with_feature_group_base_with_additional_params( feature_store_session, feature_group, offline_store_s3_uri ) - with timeout(minutes=10) and cleanup_offline_store( - base, feature_store_session - ) and cleanup_offline_store(feature_group, feature_store_session): + with ( + timeout(minutes=10), + cleanup_offline_store(base, feature_store_session), + cleanup_offline_store(feature_group, feature_store_session), + ): feature_store = FeatureStore(sagemaker_session=feature_store_session) df, query_string = ( feature_store.create_dataset(base=base, output_path=offline_store_s3_uri) diff --git a/tests/integ/test_horovod.py b/tests/integ/test_horovod.py index 2ddcdc92e0..78314c2ade 100644 --- a/tests/integ/test_horovod.py +++ b/tests/integ/test_horovod.py @@ -62,11 +62,8 @@ def test_hvd_gpu( tmpdir, **kwargs, ): - if ( - Version(tensorflow_training_latest_version) >= Version("2.12") - and kwargs["instance_type"] == "ml.p2.xlarge" - ): - pytest.skip("P2 instances have been deprecated for sagemaker jobs starting TensorFlow 2.12") + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") if Version(tensorflow_training_latest_version) >= Version("2.13"): pytest.skip("Horovod is deprecated in TensorFlow 2.13 and above") diff --git
a/tests/integ/test_horovod_mx.py b/tests/integ/test_horovod_mx.py index 7bd6a641e0..a238966dd3 100644 --- a/tests/integ/test_horovod_mx.py +++ b/tests/integ/test_horovod_mx.py @@ -58,6 +58,9 @@ def test_hvd_gpu( tmpdir, **kwargs, ): + if kwargs["instance_type"] == "ml.p2.xlarge": + pytest.skip("Instance type ml.p2.xlarge has been deprecated") + _create_and_fit_estimator( mxnet_training_latest_version, mxnet_training_latest_py_version, diff --git a/tests/integ/test_huggingface.py b/tests/integ/test_huggingface.py index a8be54c4d4..9098d8359a 100644 --- a/tests/integ/test_huggingface.py +++ b/tests/integ/test_huggingface.py @@ -29,6 +29,10 @@ @pytest.mark.release +@pytest.mark.skipif( + tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS, + reason="No P3 instances or low capacity in this region", +) def test_framework_processing_job_with_deps( sagemaker_session, huggingface_training_latest_version, @@ -59,6 +63,10 @@ def test_framework_processing_job_with_deps( @pytest.mark.release +@pytest.mark.skipif( + tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS, + reason="No P3 instances or low capacity in this region", +) def test_huggingface_training( sagemaker_session, huggingface_training_latest_version, diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py index 9e6b41d753..6504932a7e 100644 --- a/tests/integ/test_inference_pipeline.py +++ b/tests/integ/test_inference_pipeline.py @@ -150,6 +150,40 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type): assert "Could not find model" in str(exception.value) +@pytest.mark.release +def test_inference_pipeline_model_register(sagemaker_session): + sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model") + endpoint_name = unique_name_from_base("test-inference-pipeline-deploy") + sparkml_model_data = sagemaker_session.upload_data( + path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"), + key_prefix="integ-test-data/sparkml/model", + ) + + sparkml_model = SparkMLModel( + model_data=sparkml_model_data, + env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA}, + sagemaker_session=sagemaker_session, + ) + + model = PipelineModel( + models=[sparkml_model], + role="SageMakerRole", + sagemaker_session=sagemaker_session, + name=endpoint_name, + ) + model_package_group_name = unique_name_from_base("pipeline-model-package") + model_package = model.register(model_package_group_name=model_package_group_name) + assert model_package.model_package_arn is not None + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package.model_package_arn + ) + + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) + + @pytest.mark.slow_test @pytest.mark.flaky(reruns=5, reruns_delay=2) def test_inference_pipeline_model_deploy_and_update_endpoint( diff --git a/tests/integ/test_model_package.py b/tests/integ/test_model_package.py index 914c5db7ed..1ac8e33fd8 100644 --- a/tests/integ/test_model_package.py +++ b/tests/integ/test_model_package.py @@ -12,14 +12,24 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import +import json import os -from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum +from sagemaker.model_card.model_card import ( + AdditionalInformation, + BusinessDetails, + IntendedUses, + ModelCard, + ModelOverview, + ModelPackageModelCard, +) +from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum, ModelCardStatusEnum from sagemaker.utils import unique_name_from_base from tests.integ import DATA_DIR from sagemaker.xgboost import XGBoostModel from sagemaker import image_uris from sagemaker.session import get_execution_role from sagemaker.model import ModelPackage +from sagemaker.model_life_cycle import ModelLifeCycle _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone") @@ -66,6 +76,62 @@ def test_update_approval_model_package(sagemaker_session): ) +def test_update_model_life_cycle_model_package(sagemaker_session): + + model_group_name = unique_name_from_base("test-model-group") + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + model = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + + create_model_life_cycle = ModelLifeCycle( + stage="Development", + stage_status="In-Progress", + stage_description="Development In Progress", + ) + model_package = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + model_life_cycle=create_model_life_cycle, + ) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + create_model_life_cycle_req = create_model_life_cycle._to_request_dict() + + assert desc_model_package["ModelLifeCycle"] == create_model_life_cycle_req + + update_model_life_cycle = ModelLifeCycle( + stage="Staging", + stage_status="In-Progress", + stage_description="Sending for Staging Verification", + ) + update_model_life_cycle_req = update_model_life_cycle._to_request_dict() + + model_package.update_model_life_cycle(update_model_life_cycle_req) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + assert desc_model_package["ModelLifeCycle"] == update_model_life_cycle_req + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package.model_package_arn + ) + + def test_inference_specification_addition(sagemaker_session): model_group_name = unique_name_from_base("test-model-group") @@ -183,6 +249,216 @@ def test_update_source_uri(sagemaker_session): assert desc_model_package["SourceUri"] == source_uri +def test_update_model_card_with_model_card_object(sagemaker_session): + model_group_name = unique_name_from_base("test-model-group") + intended_uses = IntendedUses( + purpose_of_model="Test model card.", + intended_uses="Not used except this test.", + factors_affecting_model_efficiency="No.", + risk_rating="Low", + explanations_for_risk_rating="Just an example.", + ) + business_details = BusinessDetails( + business_problem="The business problem that your model is used to solve.", + business_stakeholders="The stakeholders who have the interest in the business that your model is used 
for.", + line_of_business="Services that the business is offering.", + ) + additional_information = AdditionalInformation( + ethical_considerations="Your model ethical consideration.", + caveats_and_recommendations="Your model's caveats and recommendations.", + custom_details={"custom details1": "details value"}, + ) + + model_overview = ModelOverview(model_creator="TestCreator") + + my_card = ModelCard( + name="TestName", + sagemaker_session=sagemaker_session, + status=ModelCardStatusEnum.DRAFT, + model_overview=model_overview, + intended_uses=intended_uses, + business_details=business_details, + additional_information=additional_information, + ) + + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_group_name + ) + + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + model = XGBoostModel( + model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session + ) + + model_package = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_group_name, + model_card=my_card, + ) + + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + updated_model_overview = ModelOverview(model_creator="updatedCreator") + updated_intended_uses = IntendedUses( + purpose_of_model="Updated Test model card.", + ) + updated_my_card = ModelCard( + name="TestName", + sagemaker_session=sagemaker_session, + model_overview=updated_model_overview, + intended_uses=updated_intended_uses, + ) + model_package.update_model_card(updated_my_card) + desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( + ModelPackageName=model_package.model_package_arn + ) + + model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"]) + assert model_card_content["intended_uses"]["purpose_of_model"] == "Updated Test model card." 
+    assert model_card_content["model_overview"]["model_creator"] == "updatedCreator"
+    updated_my_card_status = ModelCard(
+        name="TestName",
+        sagemaker_session=sagemaker_session,
+        status=ModelCardStatusEnum.PENDING_REVIEW,
+    )
+    model_package.update_model_card(updated_my_card_status)
+    desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+
+    model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"])
+    assert desc_model_package["ModelCard"]["ModelCardStatus"] == ModelCardStatusEnum.PENDING_REVIEW
+
+
+def test_update_model_card_with_model_card_json(sagemaker_session):
+    model_group_name = unique_name_from_base("test-model-group")
+    model_card_content = {
+        "model_overview": {
+            "model_creator": "TestCreator",
+        },
+        "intended_uses": {
+            "purpose_of_model": "Test model card.",
+            "intended_uses": "Not used except this test.",
+            "factors_affecting_model_efficiency": "No.",
+            "risk_rating": "Low",
+            "explanations_for_risk_rating": "Just an example.",
+        },
+        "business_details": {
+            "business_problem": "The business problem that your model is used to solve.",
+            "business_stakeholders": "The stakeholders who have the interest in the business.",
+            "line_of_business": "Services that the business is offering.",
+        },
+        "evaluation_details": [
+            {
+                "name": "Example evaluation job",
+                "evaluation_observation": "Evaluation observations.",
+                "metric_groups": [
+                    {
+                        "name": "binary classification metrics",
+                        "metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}],
+                    }
+                ],
+            }
+        ],
+        "additional_information": {
+            "ethical_considerations": "Your model ethical consideration.",
+            "caveats_and_recommendations": "Your model's caveats and recommendations.",
+            "custom_details": {"custom details1": "details value"},
+        },
+    }
+    my_card = ModelPackageModelCard(
+        model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=model_card_content
+    )
+
+    sagemaker_session.sagemaker_client.create_model_package_group(
+        ModelPackageGroupName=model_group_name
+    )
+
+    xgb_model_data_s3 = sagemaker_session.upload_data(
+        path=os.path.join(_XGBOOST_PATH, "xgb_model.tar.gz"),
+        key_prefix="integ-test-data/xgboost/model",
+    )
+    model = XGBoostModel(
+        model_data=xgb_model_data_s3, framework_version="1.3-1", sagemaker_session=sagemaker_session
+    )
+
+    model_package = model.register(
+        content_types=["text/csv"],
+        response_types=["text/csv"],
+        inference_instances=["ml.m5.large"],
+        transform_instances=["ml.m5.large"],
+        model_package_group_name=model_group_name,
+        model_card=my_card,
+    )
+
+    desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+
+    updated_model_card_content = {
+        "model_overview": {
+            "model_creator": "updatedCreator",
+        },
+        "intended_uses": {
+            "purpose_of_model": "Updated Test model card.",
+            "intended_uses": "Not used except this test.",
+            "factors_affecting_model_efficiency": "No.",
+            "risk_rating": "Low",
+            "explanations_for_risk_rating": "Just an example.",
+        },
+        "business_details": {
+            "business_problem": "The business problem that your model is used to solve.",
+            "business_stakeholders": "The stakeholders who have the interest in the business.",
+            "line_of_business": "Services that the business is offering.",
+        },
+        "evaluation_details": [
+            {
+                "name": "Example evaluation job",
+                "evaluation_observation": "Evaluation observations.",
+                "metric_groups": [
+                    {
+                        "name": "binary classification metrics",
+                        "metric_data": [{"name": "accuracy", "type": "number", "value": 0.5}],
+                    }
+                ],
+            }
+        ],
+        "additional_information": {
+            "ethical_considerations": "Your model ethical consideration.",
+            "caveats_and_recommendations": "Your model's caveats and recommendations.",
+            "custom_details": {"custom details1": "details value"},
+        },
+    }
+    updated_my_card = ModelPackageModelCard(
+        model_card_status=ModelCardStatusEnum.DRAFT, model_card_content=updated_model_card_content
+    )
+    model_package.update_model_card(updated_my_card)
+    desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+
+    model_card_content = json.loads(desc_model_package["ModelCard"]["ModelCardContent"])
+    assert model_card_content["intended_uses"]["purpose_of_model"] == "Updated Test model card."
+    assert model_card_content["model_overview"]["model_creator"] == "updatedCreator"
+    updated_my_card_status = ModelPackageModelCard(
+        model_card_status=ModelCardStatusEnum.PENDING_REVIEW,
+    )
+    model_package.update_model_card(updated_my_card_status)
+    desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(
+        ModelPackageName=model_package.model_package_arn
+    )
+
+    assert desc_model_package["ModelCard"]["ModelCardStatus"] == ModelCardStatusEnum.PENDING_REVIEW
+
+
 def test_clone_model_package_using_source_uri(sagemaker_session):
     model_group_name = unique_name_from_base("test-model-group")
diff --git a/tests/integ/test_multidatamodel.py b/tests/integ/test_multidatamodel.py
index 59c79f5a9c..4c926a1c0e 100644
--- a/tests/integ/test_multidatamodel.py
+++ b/tests/integ/test_multidatamodel.py
@@ -14,6 +14,7 @@
 import base64
 import os
+import time
 import requests
 import docker
@@ -138,6 +139,7 @@ def test_multi_data_model_deploy_pretrained_models(
     multi_data_model.add_model(pretrained_model_data_local_path, PRETRAINED_MODEL_PATH_1)
     # Deploy model to an endpoint
     multi_data_model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
+    time.sleep(30)
     # Add models after deploy
     multi_data_model.add_model(pretrained_model_data_local_path, PRETRAINED_MODEL_PATH_2)
@@ -266,6 +268,7 @@ def test_multi_data_model_deploy_trained_model_from_framework_estimator(
     multi_data_model.add_model(mxnet_model_1.model_data, PRETRAINED_MODEL_PATH_1)
     # Deploy model to an endpoint
     multi_data_model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
+    time.sleep(30)
     # Train another model
     mxnet_model_2 = _mxnet_training_job(
@@ -373,6 +376,7 @@ def test_multi_data_model_deploy_train_model_from_amazon_first_party_estimator(
     multi_data_model.add_model(rcf_model_v1.model_data, PRETRAINED_MODEL_PATH_1)
     # Deploy model to an endpoint
     multi_data_model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
+    time.sleep(30)
     # Train another model
     rcf_model_v2 = __rcf_training_job(
         sagemaker_session, container_image, cpu_instance_type, 70, 20
@@ -470,6 +474,7 @@ def test_multi_data_model_deploy_pretrained_models_update_endpoint(
     multi_data_model.add_model(pretrained_model_data_local_path, PRETRAINED_MODEL_PATH_1)
     # Deploy model to an endpoint
     multi_data_model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
+    time.sleep(30)
     # Add model after deploy
     multi_data_model.add_model(pretrained_model_data_local_path, PRETRAINED_MODEL_PATH_2)
diff --git a/tests/integ/test_processing.py b/tests/integ/test_processing.py
index 8ceb3f2195..3be778ba84 100644
--- a/tests/integ/test_processing.py
+++ b/tests/integ/test_processing.py
@@ -20,7 +20,6 @@
 from sagemaker import image_uris, Session
 from sagemaker.dataset_definition.inputs import (
     DatasetDefinition,
-    RedshiftDatasetDefinition,
     AthenaDatasetDefinition,
     S3Input,
 )
@@ -766,25 +765,6 @@ def _get_processing_inputs_with_all_parameters(bucket):
                 s3_compression_type="None",
             ),
         ),
-        ProcessingInput(
-            input_name="redshift_dataset_definition",
-            app_managed=True,
-            dataset_definition=DatasetDefinition(
-                local_path="/opt/ml/processing/input/rdd",
-                data_distribution_type="FullyReplicated",
-                input_mode="File",
-                redshift_dataset_definition=RedshiftDatasetDefinition(
-                    cluster_id="integ-test-cluster-prod-us-west-2",
-                    database="dev",
-                    db_user="awsuser",
-                    query_string="SELECT * FROM shoes",
-                    cluster_role_arn="arn:aws:iam::037210630505:role/RedshiftClusterRole-prod-us-west-2",
-                    output_s3_uri=f"s3://{bucket}/rdd",
-                    output_format="CSV",
-                    output_compression="None",
-                ),
-            ),
-        ),
         ProcessingInput(
             input_name="athena_dataset_definition",
             app_managed=True,
@@ -853,25 +833,6 @@ def _get_processing_job_inputs_and_outputs(bucket, output_kms_key):
                 "S3CompressionType": "None",
             },
         },
-        {
-            "InputName": "redshift_dataset_definition",
-            "AppManaged": True,
-            "DatasetDefinition": {
-                "RedshiftDatasetDefinition": {
-                    "ClusterId": "integ-test-cluster-prod-us-west-2",
-                    "Database": "dev",
-                    "DbUser": "awsuser",
-                    "QueryString": "SELECT * FROM shoes",
-                    "ClusterRoleArn": "arn:aws:iam::037210630505:role/RedshiftClusterRole-prod-us-west-2",
-                    "OutputS3Uri": f"s3://{bucket}/rdd",
-                    "OutputFormat": "CSV",
-                    "OutputCompression": "None",
-                },
-                "LocalPath": "/opt/ml/processing/input/rdd",
-                "DataDistributionType": "FullyReplicated",
-                "InputMode": "File",
-            },
-        },
         {
             "InputName": "athena_dataset_definition",
             "AppManaged": True,
diff --git a/tests/integ/test_pytorch.py b/tests/integ/test_pytorch.py
index 94ce71f90a..fc686d1f3b 100644
--- a/tests/integ/test_pytorch.py
+++ b/tests/integ/test_pytorch.py
@@ -95,6 +95,10 @@ def fixture_training_job_with_latest_inference_version(
     return pytorch.latest_training_job.name
+@pytest.mark.skip(
+    reason="The test is temporarily disabled because it's causing errors with PyTorch 2.4.0. \
+Please run it manually until a proper fix is in place."
+)
 @pytest.mark.release
 def test_framework_processing_job_with_deps(
     sagemaker_session,
@@ -124,6 +128,10 @@ def test_framework_processing_job_with_deps(
     )
+@pytest.mark.skip(
+    reason="The test is temporarily disabled because it's causing errors with PyTorch 2.4.0. \
+Please run it manually until a proper fix is in place."
+)
 @pytest.mark.release
 def test_fit_deploy(
     pytorch_training_job_with_latest_infernce_version, sagemaker_session, cpu_instance_type
@@ -144,6 +152,10 @@ def test_fit_deploy(
     assert output.shape == (batch_size, 10)
+@pytest.mark.skip(
+    reason="The test is temporarily disabled because it's causing errors with PyTorch 2.4.0. \
+Please run it manually until a proper fix is in place."
+)
 @pytest.mark.local_mode
 def test_local_fit_deploy(
     sagemaker_local_session, pytorch_inference_latest_version, pytorch_inference_latest_py_version
@@ -171,6 +183,10 @@ def test_local_fit_deploy(
     predictor.delete_endpoint()
+@pytest.mark.skip(
+    reason="The test is temporarily disabled because it's causing errors with PyTorch 2.4.0. \
+Please run it manually until a proper fix is in place."
+)
 def test_deploy_model(
     pytorch_training_job,
     sagemaker_session,
@@ -202,6 +218,10 @@ def test_deploy_model(
     assert output.shape == (batch_size, 10)
+@pytest.mark.skip(
+    reason="The test is temporarily disabled because it's causing errors with PyTorch 2.4.0. \
+Please run it manually until a proper fix is in place."
+)
 def test_deploy_packed_model_with_entry_point_name(
     sagemaker_session,
     cpu_instance_type,
@@ -229,6 +249,10 @@ def test_deploy_packed_model_with_entry_point_name(
     assert output.shape == (batch_size, 10)
+@pytest.mark.skip(
+    reason="The test is temporarily disabled because it's causing errors with PyTorch 2.4.0. \
+Please run it manually until a proper fix is in place."
+)
 def test_deploy_model_with_serverless_inference_config(
     pytorch_training_job,
     sagemaker_session,
diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py
index 0015efe3fd..0b2900bef7 100644
--- a/tests/integ/test_session.py
+++ b/tests/integ/test_session.py
@@ -15,7 +15,8 @@
 import boto3
 from botocore.config import Config
-from sagemaker import Session
+from sagemaker import Session, ModelPackage
+from sagemaker.utils import unique_name_from_base
 CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist"
@@ -44,3 +45,62 @@ def test_sagemaker_session_does_not_create_bucket_on_init(
     s3 = boto3.resource("s3", region_name=boto_session.region_name)
     assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None
+
+
+def test_sagemaker_session_to_return_most_recent_approved_model_package(sagemaker_session):
+    model_package_group_name = unique_name_from_base("test-model-package-group")
+    approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
+        model_package_group_name=model_package_group_name
+    )
+    assert approved_model_package is None
+    sagemaker_session.sagemaker_client.create_model_package_group(
+        ModelPackageGroupName=model_package_group_name
+    )
+    approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
+        model_package_group_name=model_package_group_name
+    )
+    assert approved_model_package is None
+    source_uri = "dummy source uri"
+    model_package = sagemaker_session.sagemaker_client.create_model_package(
+        ModelPackageGroupName=model_package_group_name, SourceUri=source_uri
+    )
+    approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
+        model_package_group_name=model_package_group_name
+    )
+    assert approved_model_package is None
+    ModelPackage(
+        sagemaker_session=sagemaker_session,
+        model_package_arn=model_package["ModelPackageArn"],
+    ).update_approval_status(approval_status="Approved")
+    approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
+        model_package_group_name=model_package_group_name
+    )
+    assert approved_model_package is not None
+    assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn")
+    model_package_2 = sagemaker_session.sagemaker_client.create_model_package(
+        ModelPackageGroupName=model_package_group_name, SourceUri=source_uri
+    )
+    approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
+        model_package_group_name=model_package_group_name
+    )
+    assert approved_model_package is not None
+    assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn")
+    ModelPackage(
+        sagemaker_session=sagemaker_session,
+        model_package_arn=model_package_2["ModelPackageArn"],
+    ).update_approval_status(approval_status="Approved")
+    approved_model_package =
sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is not None + assert approved_model_package.model_package_arn == model_package_2.get("ModelPackageArn") + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package_2["ModelPackageArn"] + ) + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package["ModelPackageArn"] + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) diff --git a/tests/integ/test_sklearn.py b/tests/integ/test_sklearn.py index 839e601d34..ff5b466b3f 100644 --- a/tests/integ/test_sklearn.py +++ b/tests/integ/test_sklearn.py @@ -159,8 +159,6 @@ def test_deploy_model( def test_deploy_model_with_serverless_inference_config( sklearn_training_job, sagemaker_session, - sklearn_latest_version, - sklearn_latest_py_version, ): endpoint_name = unique_name_from_base("test-sklearn-deploy-model-serverless") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): @@ -173,7 +171,7 @@ def test_deploy_model_with_serverless_inference_config( model_data, ROLE, entry_point=script_path, - framework_version=sklearn_latest_version, + framework_version="1.0-1", sagemaker_session=sagemaker_session, ) predictor = model.deploy( diff --git a/tests/integ/test_spark_processing.py b/tests/integ/test_spark_processing.py index 25a4942d70..ac956be94e 100644 --- a/tests/integ/test_spark_processing.py +++ b/tests/integ/test_spark_processing.py @@ -35,7 +35,7 @@ SPARK_PATH = os.path.join(DATA_DIR, "spark") -@pytest.fixture(scope="module") +@pytest.fixture(scope="module", autouse=True) def build_jar(): jar_file_path = os.path.join(SPARK_PATH, "code", "java", "hello-java-spark") # compile java file @@ -69,9 +69,6 @@ def build_jar(): ".", ] ) - yield - subprocess.run(["rm", os.path.join(jar_file_path, "hello-spark-java.jar")]) - subprocess.run(["rm", os.path.join(jar_file_path, JAVA_FILE_PATH, "HelloJavaSparkApp.class")]) @pytest.fixture(scope="module") @@ -207,12 +204,10 @@ def configuration() -> list: def test_sagemaker_pyspark_v3( - spark_v3_py_processor, spark_v3_jar_processor, sagemaker_session, configuration, build_jar + spark_v3_py_processor, spark_v3_jar_processor, sagemaker_session, configuration ): test_sagemaker_pyspark_multinode(spark_v3_py_processor, sagemaker_session, configuration) - test_sagemaker_java_jar_multinode( - spark_v3_jar_processor, sagemaker_session, configuration, build_jar - ) + test_sagemaker_java_jar_multinode(spark_v3_jar_processor, sagemaker_session, configuration) def test_sagemaker_pyspark_multinode(spark_py_processor, sagemaker_session, configuration): @@ -280,9 +275,7 @@ def test_sagemaker_pyspark_multinode(spark_py_processor, sagemaker_session, conf assert len(output_contents) != 0 -def test_sagemaker_java_jar_multinode( - spark_jar_processor, sagemaker_session, configuration, build_jar -): +def test_sagemaker_java_jar_multinode(spark_jar_processor, sagemaker_session, configuration): """Test SparkJarProcessor using Java application jar""" bucket = spark_jar_processor.sagemaker_session.default_bucket() with open(os.path.join(SPARK_PATH, "files", "data.jsonl")) as data: diff --git a/tests/integ/test_tf.py b/tests/integ/test_tf.py index b03b0e60ec..db44acd5dc 100644 --- a/tests/integ/test_tf.py +++ b/tests/integ/test_tf.py @@ -85,6 +85,11 @@ def test_mnist_with_checkpoint_config( tensorflow_training_latest_version, 
    tensorflow_training_latest_py_version,
 ):
+    if Version(tensorflow_training_latest_version) >= Version("2.16"):
+        pytest.skip(
+            "This test is failing in TensorFlow 2.16 because of an upstream bug: "
+            "https://github.com/tensorflow/io/issues/2039"
+        )
     checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}".format(
         sagemaker_session.default_bucket(), sagemaker_timestamp()
     )
@@ -235,6 +240,11 @@ def test_mnist_distributed_cpu(
     tensorflow_training_latest_version,
     tensorflow_training_latest_py_version,
 ):
+    if Version(tensorflow_training_latest_version) >= Version("2.16"):
+        pytest.skip(
+            "This test is failing in TensorFlow 2.16 because of an upstream bug: "
+            "https://github.com/tensorflow/io/issues/2039"
+        )
     _create_and_fit_estimator(
         sagemaker_session,
         tensorflow_training_latest_version,
@@ -296,6 +306,11 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
 @pytest.mark.slow_test
 def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version):
+    if Version(tf_full_version) >= Version("2.16"):
+        pytest.skip(
+            "This test is failing in TensorFlow 2.16 because of an upstream bug: "
+            "https://github.com/tensorflow/io/issues/2039"
+        )
     if tf_full_version == "2.7.0":
         tf_full_version = "2.7"
diff --git a/tests/integ/test_training_compiler.py b/tests/integ/test_training_compiler.py
index 803be0013e..1251eb0723 100644
--- a/tests/integ/test_training_compiler.py
+++ b/tests/integ/test_training_compiler.py
@@ -90,6 +90,10 @@ def skip_if_incompatible(gpu_instance_type, request):
         pytest.param("ml.p3.16xlarge", 2),
     ],
 )
+@pytest.mark.skipif(
+    integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
+    reason="No P3 instances or low capacity in this region",
+)
 def test_huggingface_pytorch(
     sagemaker_session,
     gpu_instance_type,
diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py
index d25f45d4db..0d03aee8ea 100644
--- a/tests/integ/test_transformer.py
+++ b/tests/integ/test_transformer.py
@@ -18,6 +18,8 @@
 import pytest
+from packaging.version import Version
+
 from sagemaker import KMeans, s3, get_execution_role
 from sagemaker.mxnet import MXNet
 from sagemaker.pytorch import PyTorchModel
@@ -553,6 +555,12 @@ def test_transform_mxnet_logs(
 def test_transform_tf_kms_network_isolation(
     sagemaker_session, cpu_instance_type, tmpdir, tf_full_version, tf_full_py_version
 ):
+    if Version(tf_full_version) >= Version("2.16"):
+        pytest.skip(
+            "This test is failing in TensorFlow 2.16 because of an upstream bug: "
+            "https://github.com/tensorflow/io/issues/2039"
+        )
+
     data_path = os.path.join(DATA_DIR, "tensorflow_mnist")
     tf = TensorFlow(
diff --git a/tests/integ/test_tuner.py b/tests/integ/test_tuner.py
index 3a41ea0094..78e1e50180 100644
--- a/tests/integ/test_tuner.py
+++ b/tests/integ/test_tuner.py
@@ -19,6 +19,7 @@
 import numpy as np
 import pytest
 from botocore.exceptions import ClientError
+from packaging.version import Version
 import tests.integ
 from sagemaker import KMeans, LDA, RandomCutForest, image_uris
@@ -691,6 +692,11 @@ def test_tuning_tf(
     tensorflow_training_latest_version,
     tensorflow_training_latest_py_version,
 ):
+    if Version(tensorflow_training_latest_version) >= Version("2.16"):
+        pytest.skip(
+            "This test is failing in TensorFlow 2.16 because of an upstream bug: "
+            "https://github.com/tensorflow/io/issues/2039"
+        )
     resource_path = os.path.join(DATA_DIR, "tensorflow_mnist")
     script_path = "mnist.py"
@@ -735,6 +741,11 @@ def test_tuning_tf_vpc_multi(
     tensorflow_training_latest_py_version,
 ):
     """Test Tensorflow multi-instance using the same VpcConfig for training and inference"""
+    if Version(tensorflow_training_latest_version) >= Version("2.16"):
+        pytest.skip(
+            "This test is failing in TensorFlow 2.16 because of an upstream bug: "
+            "https://github.com/tensorflow/io/issues/2039"
+        )
     instance_type = cpu_instance_type
     instance_count = 2
diff --git a/tests/integ/test_xgboost.py b/tests/integ/test_xgboost.py
index 7b4db837fd..1c06c6b5c6 100644
--- a/tests/integ/test_xgboost.py
+++ b/tests/integ/test_xgboost.py
@@ -121,11 +121,9 @@ def test_training_with_network_isolation(
 ]
-@pytest.mark.skip(reason="re:Invent keynote3 blocker. Revisit after release")
 def test_xgboost_serverless_inference(
     xgboost_training_job,
     sagemaker_session,
-    xgboost_latest_version,
 ):
     endpoint_name = unique_name_from_base("test-xgboost-deploy-model-serverless")
     with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
@@ -139,7 +137,7 @@ def test_xgboost_serverless_inference(
             model_data=model_data,
             role=ROLE,
             entry_point=os.path.join(DATA_DIR, "xgboost_abalone", "abalone.py"),
-            framework_version=xgboost_latest_version,
+            framework_version="1.5-1",
         )
         xgboost.deploy(
diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py
index 11165a0625..91c132f053 100644
--- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py
+++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py
@@ -13,7 +13,7 @@
 from __future__ import absolute_import
 import boto3
-from mock.mock import patch, Mock
+from mock.mock import patch, Mock, ANY
 from sagemaker import accept_types
 from sagemaker.jumpstart.utils import verify_model_region_and_return_specs
@@ -54,9 +54,11 @@ def test_jumpstart_default_accept_types(
     patched_get_model_specs.assert_called_once_with(
         region=region,
         model_id=model_id,
+        hub_arn=None,
         version=model_version,
-        s3_client=mock_client,
+        s3_client=ANY,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        sagemaker_session=mock_session,
     )
@@ -91,6 +93,8 @@ def test_jumpstart_supported_accept_types(
         region=region,
         model_id=model_id,
         version=model_version,
-        s3_client=mock_client,
+        s3_client=ANY,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
diff --git a/tests/unit/sagemaker/aws_batch/__init__.py b/tests/unit/sagemaker/aws_batch/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit/sagemaker/aws_batch/constants.py b/tests/unit/sagemaker/aws_batch/constants.py
new file mode 100644
index 0000000000..8745e3558f
--- /dev/null
+++ b/tests/unit/sagemaker/aws_batch/constants.py
@@ -0,0 +1,72 @@
+from __future__ import absolute_import
+
+
+TRAINING_JOB_NAME = "my-training-job"
+TRAINING_IMAGE = "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.8.0-cpu-py3"
+TRAINING_INPUT_MODE = "File"
+CONTAINER_ENTRYPOINT = ["echo", "hello"]
+EXECUTION_ROLE = "myrole"
+S3_OUTPUT_PATH = "s3://output"
+INSTANCE_TYPE = "ml.m4.xlarge"
+INSTANCE_COUNT = 1
+VOLUME_SIZE_IN_GB = 1
+MAX_RUNTIME_IN_SECONDS = 600
+TRAINING_JOB_ARN = "arn:aws:sagemaker:us-west-2:476748761737:training-job/jobName"
+JOB_NAME = "jobName"
+JOB_NAME_IN_PAYLOAD = "jobNameInPayload"
+JOB_ID = "123"
+JOB_ARN = "arn:batch:job"
+JOB_QUEUE = "testQueue"
+JOB_STATUS_RUNNABLE = "RUNNABLE"
+JOB_STATUS_RUNNING = "RUNNING"
+JOB_STATUS_COMPLETED = "SUCCEEDED"
+JOB_STATUS_FAILED = "FAILED"
+NEXT_TOKEN = "SomeNextToken"
+SCHEDULING_PRIORITY = 1
+ATTEMPT_DURATION_IN_SECONDS = 100
+REASON =
"killed by Batch API" +SHARE_IDENTIFIER = "shareId" +BATCH_TAGS = {"batch_k": "batch_v"} +TRAINING_TAGS = [{"Key": "training_k", "Value": "training_v"}] +TRAINING_TAGS_DUPLICATING_BATCH_TAGS = [ + *TRAINING_TAGS, + {"Key": "batch_k", "Value": "this value should win"}, +] +TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS = {"training_k": "training_v"} +MERGED_TAGS = {**BATCH_TAGS, **TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS} +MERGED_TAGS_TRAINING_OVERRIDE = { + **TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS, + "batch_k": "this value should win", +} +EXPERIMENT_CONFIG_EMPTY = {} + +TRAINING_JOB_PAYLOAD_IN_PASCALCASE = {"TrainingJobName": JOB_NAME_IN_PAYLOAD} +TIMEOUT_CONFIG = {"attemptDurationSeconds": ATTEMPT_DURATION_IN_SECONDS} +SUBMIT_SERVICE_JOB_RESP = {"jobArn": JOB_ARN, "jobName": JOB_NAME, "jobId": JOB_ID} +FIRST_LIST_SERVICE_JOB_RESP = { + "jobSummaryList": [{"jobName": JOB_NAME, "jobArn": JOB_ARN}], + "nextToken": NEXT_TOKEN, +} +SECOND_LIST_SERVICE_JOB_RESP = { + "jobSummaryList": [ + {"jobName": JOB_NAME, "jobArn": JOB_ARN}, + {"jobName": JOB_NAME, "jobArn": JOB_ARN}, + ], + "nextToken": NEXT_TOKEN, +} +INCORRECT_FIRST_LIST_SERVICE_JOB_RESP = { + "jobSummaryList": [{"jobName": JOB_NAME}], + "nextToken": NEXT_TOKEN, +} +EMPTY_LIST_SERVICE_JOB_RESP = {"jobSummaryList": [], "nextToken": None} +DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG = { + "attempts": 1, + "evaluateOnExit": [ + { + "action": "RETRY", + "onStatusReason": "Received status from SageMaker:InternalServerError: " + "We encountered an internal error. Please try again.", + }, + {"action": "EXIT", "onStatusReason": "*"}, + ], +} diff --git a/tests/unit/sagemaker/aws_batch/mock_client.py b/tests/unit/sagemaker/aws_batch/mock_client.py new file mode 100644 index 0000000000..c13bb9db93 --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/mock_client.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import +from typing import Optional, List, Dict +from .constants import ( + JOB_ARN, + JOB_ID, + FIRST_LIST_SERVICE_JOB_RESP, + EMPTY_LIST_SERVICE_JOB_RESP, + JOB_STATUS_RUNNING, + TIMEOUT_CONFIG, +) + + +class MockClient: + def submit_service_job( + self, + jobName, + jobQueue, + serviceRequestPayload, + serviceJobType, + retryStrategy: Optional[Dict] = None, + schedulingPriority: Optional[int] = None, + shareIdentifier: Optional[str] = "", + tags: Optional[Dict] = None, + timeoutConfig: Optional[Dict] = TIMEOUT_CONFIG, + ): + return {"jobArn": JOB_ARN, "jobName": jobName, "jobId": JOB_ID} + + def describe_service_job(self, jobId): + return {"jobId": jobId} + + def terminate_service_job(self, jobId, reason): + return {} + + def list_service_jobs( + self, + jobQueue, + jobStatus: Optional[str] = JOB_STATUS_RUNNING, + nextToken: Optional[str] = "", + filters: Optional[List] = [], + ): + if nextToken: + return FIRST_LIST_SERVICE_JOB_RESP + else: + return EMPTY_LIST_SERVICE_JOB_RESP diff --git a/tests/unit/sagemaker/aws_batch/mock_estimator.py b/tests/unit/sagemaker/aws_batch/mock_estimator.py new file mode 100644 index 0000000000..aa3d9e1b20 --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/mock_estimator.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import +from sagemaker.estimator import Estimator +from sagemaker.pytorch import PyTorch + + +class Estimator(Estimator): + def __init__(self): + self.sagemaker_session = Session() + self.tags = [ + {"Key": "batch-non-prod", "Value": "true"}, + {"Key": "batch-training-job-name", "Value": "training-job"}, + ] + + def prepare_workflow_for_training(self, job_name): + pass + + +class PyTorch(PyTorch): + def 
__init__(self): + self.sagemaker_session = Session() + self.tags = [ + {"Key": "batch-non-prod", "Value": "true"}, + {"Key": "batch-training-job-name", "Value": "training-job"}, + ] + + def prepare_workflow_for_training(self, job_name): + pass + + +class Session: + def __init__(self): + pass + + def get_train_request(self, **kwargs): + return kwargs diff --git a/tests/unit/sagemaker/aws_batch/test_batch_api_helper.py b/tests/unit/sagemaker/aws_batch/test_batch_api_helper.py new file mode 100644 index 0000000000..e9384c135c --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/test_batch_api_helper.py @@ -0,0 +1,186 @@ +from __future__ import absolute_import +from sagemaker.aws_batch.batch_api_helper import ( + submit_service_job, + terminate_service_job, + describe_service_job, + list_service_job, + __merge_tags, +) + +import json +import pytest +from mock.mock import patch + +from sagemaker.aws_batch.constants import ( + DEFAULT_TIMEOUT, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SAGEMAKER_TRAINING, +) +from .mock_client import MockClient +from .constants import ( + JOB_NAME, + JOB_QUEUE, + SCHEDULING_PRIORITY, + JOB_ID, + REASON, + SHARE_IDENTIFIER, + BATCH_TAGS, + TRAINING_TAGS, + TRAINING_TAGS_DUPLICATING_BATCH_TAGS, + TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS, + MERGED_TAGS, + MERGED_TAGS_TRAINING_OVERRIDE, + JOB_STATUS_RUNNING, + NEXT_TOKEN, +) + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +def test_submit_service_job(patched_get_batch_boto_client): + patched_get_batch_boto_client.return_value = MockClient() + training_payload = {} + resp = submit_service_job( + training_payload, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + DEFAULT_TIMEOUT, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + assert resp["jobName"] == JOB_NAME + assert "jobArn" in resp + assert "jobId" in resp + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +@patch("sagemaker.aws_batch.batch_api_helper.__merge_tags") +@pytest.mark.parametrize( + "batch_tags,training_tags", + [ + (BATCH_TAGS, TRAINING_TAGS), + (None, TRAINING_TAGS), + ({}, TRAINING_TAGS), + (BATCH_TAGS, None), + (BATCH_TAGS, []), + ], +) +def test_submit_service_job_called_with_merged_tags( + patched_merge_tags, patched_get_batch_boto_client, batch_tags, training_tags +): + mock_client = MockClient() + patched_get_batch_boto_client.return_value = mock_client + patched_merge_tags.return_value = MERGED_TAGS + + with patch.object( + mock_client, "submit_service_job", wraps=mock_client.submit_service_job + ) as wrapped_submit_service_job: + training_payload = {"Tags": training_tags} + resp = submit_service_job( + training_payload, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + DEFAULT_TIMEOUT, + SHARE_IDENTIFIER, + batch_tags, + ) + assert resp["jobName"] == JOB_NAME + assert "jobArn" in resp + assert "jobId" in resp + patched_merge_tags.assert_called_once_with(batch_tags, training_tags) + wrapped_submit_service_job.assert_called_once_with( + jobName=JOB_NAME, + jobQueue=JOB_QUEUE, + retryStrategy=DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + serviceJobType=SAGEMAKER_TRAINING, + serviceRequestPayload=json.dumps(training_payload), + timeoutConfig=DEFAULT_TIMEOUT, + schedulingPriority=SCHEDULING_PRIORITY, + shareIdentifier=SHARE_IDENTIFIER, + tags={**MERGED_TAGS}, + ) + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +@patch("sagemaker.aws_batch.batch_api_helper.__merge_tags") +def 
test_submit_service_job_not_called_with_tags(patched_merge_tags, patched_get_batch_boto_client): + mock_client = MockClient() + patched_get_batch_boto_client.return_value = mock_client + patched_merge_tags.return_value = MERGED_TAGS + + with patch.object( + mock_client, "submit_service_job", wraps=mock_client.submit_service_job + ) as wrapped_submit_service_job: + training_payload = {} + resp = submit_service_job( + training_payload, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + DEFAULT_TIMEOUT, + SHARE_IDENTIFIER, + ) + assert resp["jobName"] == JOB_NAME + assert "jobArn" in resp + assert "jobId" in resp + patched_merge_tags.assert_not_called() + wrapped_submit_service_job.assert_called_once_with( + jobName=JOB_NAME, + jobQueue=JOB_QUEUE, + retryStrategy=DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + serviceJobType=SAGEMAKER_TRAINING, + serviceRequestPayload=json.dumps(training_payload), + timeoutConfig=DEFAULT_TIMEOUT, + schedulingPriority=SCHEDULING_PRIORITY, + shareIdentifier=SHARE_IDENTIFIER, + ) + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +def test_describe_service_job(patched_get_batch_boto_client): + patched_get_batch_boto_client.return_value = MockClient() + resp = describe_service_job(job_id=JOB_ID) + assert resp["jobId"] == JOB_ID + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +def test_terminate_service_job(patched_get_batch_boto_client): + patched_get_batch_boto_client.return_value = MockClient() + resp = terminate_service_job(job_id=JOB_ID, reason=REASON) + assert len(resp) == 0 + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +def test_list_service_job_has_next_token(patched_get_batch_boto_client): + patched_get_batch_boto_client.return_value = MockClient() + gen = list_service_job(job_queue=None, job_status=JOB_STATUS_RUNNING, next_token=NEXT_TOKEN) + resp = next(gen) + assert resp["nextToken"] == NEXT_TOKEN + + +@patch("sagemaker.aws_batch.batch_api_helper.get_batch_boto_client") +def test_list_service_job_no_next_token(patched_get_batch_boto_client): + patched_get_batch_boto_client.return_value = MockClient() + gen = list_service_job(job_queue=None, job_status=JOB_STATUS_RUNNING, next_token=None) + resp = next(gen) + assert resp["nextToken"] is None + + +@pytest.mark.parametrize( + "batch_tags,training_tags,expected", + [ + (BATCH_TAGS, TRAINING_TAGS, MERGED_TAGS), + (BATCH_TAGS, TRAINING_TAGS_DUPLICATING_BATCH_TAGS, MERGED_TAGS_TRAINING_OVERRIDE), + (BATCH_TAGS, None, BATCH_TAGS), + (BATCH_TAGS, [], BATCH_TAGS), + (None, TRAINING_TAGS, TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS), + ({}, TRAINING_TAGS, TRAINING_TAGS_CONVERTED_TO_BATCH_TAGS), + ], +) +def test___merge_tags(batch_tags, training_tags, expected): + result = __merge_tags(batch_tags=batch_tags, training_tags=training_tags) + assert result == expected diff --git a/tests/unit/sagemaker/aws_batch/test_training_queue.py b/tests/unit/sagemaker/aws_batch/test_training_queue.py new file mode 100644 index 0000000000..6fee3efad7 --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/test_training_queue.py @@ -0,0 +1,411 @@ +from __future__ import absolute_import +from sagemaker.aws_batch.constants import DEFAULT_TIMEOUT +from sagemaker.aws_batch.exception import MissingRequiredArgument +from sagemaker.aws_batch.training_queue import TrainingQueue + +from unittest.mock import Mock, call +from mock.mock import patch +import pytest + +from sagemaker.modules.train.model_trainer import ModelTrainer, Mode +from 
sagemaker.estimator import _TrainingJob +from .constants import ( + JOB_QUEUE, + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + JOB_ARN, + SUBMIT_SERVICE_JOB_RESP, + JOB_NAME_IN_PAYLOAD, + JOB_STATUS_RUNNING, + EMPTY_LIST_SERVICE_JOB_RESP, + FIRST_LIST_SERVICE_JOB_RESP, + INCORRECT_FIRST_LIST_SERVICE_JOB_RESP, + EXPERIMENT_CONFIG_EMPTY, + SECOND_LIST_SERVICE_JOB_RESP, + TRAINING_JOB_PAYLOAD_IN_PASCALCASE, +) +from .mock_estimator import Estimator, PyTorch + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_submit_with_timeout(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + queue = TrainingQueue(JOB_QUEUE) + queue_job = queue.submit( + Estimator(), + {}, + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + patched_submit_service_job.assert_called_once_with( + TRAINING_JOB_PAYLOAD_IN_PASCALCASE, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + TIMEOUT_CONFIG, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + assert queue_job.job_name == JOB_NAME + assert queue_job.job_arn == JOB_ARN + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_submit_use_default_timeout(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + queue = TrainingQueue(JOB_QUEUE) + queue.submit( + Estimator(), + {}, + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + None, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + patched_submit_service_job.assert_called_once_with( + TRAINING_JOB_PAYLOAD_IN_PASCALCASE, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + DEFAULT_TIMEOUT, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_submit_with_job_name(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + queue = TrainingQueue(JOB_QUEUE) + queue.submit( + Estimator(), + {}, + None, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + patched_submit_service_job.assert_called_once_with( + TRAINING_JOB_PAYLOAD_IN_PASCALCASE, + JOB_NAME_IN_PAYLOAD, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + TIMEOUT_CONFIG, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_submit_encounter_error(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = {} + + queue = TrainingQueue(JOB_QUEUE) + with pytest.raises(MissingRequiredArgument): + queue.submit( + Estimator(), + {}, + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, 
+ SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + + +def test_queue_map_with_job_names_mismatch_input_length_encounter_error(): + queue = TrainingQueue(JOB_QUEUE) + with pytest.raises(ValueError): + queue.map(Estimator(), {}, [JOB_NAME]) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_map_happy_case(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + input_list = {"test-input", "test-input-2"} + + queue = TrainingQueue(JOB_QUEUE) + queue.map( + Estimator(), + input_list, + None, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + assert patched_submit_service_job.call_count == len(input_list) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_queue_map_with_job_names(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + input_list = {"test-input", "test-input-2"} + job_names = [JOB_NAME, "job-name-2"] + + queue = TrainingQueue(JOB_QUEUE) + queue.map( + Estimator(), + input_list, + job_names, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + assert patched_submit_service_job.call_count == len(input_list) + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_list_default_argument(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + patched_list_service_job.return_value = [{"jobSummaryList": [], "nextToken": None}] + queue.list_jobs() + patched_list_service_job.assert_has_calls([call(JOB_QUEUE, JOB_STATUS_RUNNING, None, None)]) + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_list_happy_case_with_job_name(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + + patched_list_service_job.return_value = [{"jobSummaryList": [], "nextToken": None}] + + queue.list_jobs(JOB_NAME, None) + patched_list_service_job.assert_has_calls([call(JOB_QUEUE, None, filters, None)]) + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_list_happy_case_with_job_status(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = None + + patched_list_service_job.return_value = [EMPTY_LIST_SERVICE_JOB_RESP] + + queue.list_jobs(None, JOB_STATUS_RUNNING) + patched_list_service_job.assert_has_calls([call(JOB_QUEUE, JOB_STATUS_RUNNING, filters, None)]) + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_list_happy_case_has_next_token(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + + first_output = FIRST_LIST_SERVICE_JOB_RESP + second_output = SECOND_LIST_SERVICE_JOB_RESP + third_output = EMPTY_LIST_SERVICE_JOB_RESP + patched_list_service_job.return_value = iter([first_output, second_output, third_output]) + + jobs = queue.list_jobs(JOB_NAME, JOB_STATUS_RUNNING) + patched_list_service_job.assert_has_calls( + [call(JOB_QUEUE, None, filters, None)], + any_order=False, + ) + assert len(jobs) == 3 + assert 
jobs[0].job_arn == JOB_ARN + assert jobs[0].job_name == JOB_NAME + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_list_without_job_arn_in_list_resp(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + + first_output = INCORRECT_FIRST_LIST_SERVICE_JOB_RESP + second_output = EMPTY_LIST_SERVICE_JOB_RESP + patched_list_service_job.return_value = iter([first_output, second_output]) + + jobs = queue.list_jobs(JOB_NAME, JOB_STATUS_RUNNING) + patched_list_service_job.assert_has_calls( + [call(JOB_QUEUE, None, filters, None)], + any_order=False, + ) + assert len(jobs) == 0 + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_get_happy_case_job_exists(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + + patched_list_service_job.return_value = [FIRST_LIST_SERVICE_JOB_RESP] + + job = queue.get_job(JOB_NAME) + patched_list_service_job.assert_has_calls( + [call(JOB_QUEUE, None, filters, None)], + any_order=False, + ) + assert job.job_name == JOB_NAME + + +@patch("sagemaker.aws_batch.training_queue.list_service_job") +def test_queue_get_job_not_found_encounter_error(patched_list_service_job): + queue = TrainingQueue(JOB_QUEUE) + filters = [{"name": "JOB_NAME", "values": [JOB_NAME]}] + + patched_list_service_job.return_value = [EMPTY_LIST_SERVICE_JOB_RESP] + + with pytest.raises(ValueError): + queue.get_job(JOB_NAME) + patched_list_service_job.assert_has_calls([call(JOB_QUEUE, None, filters, None)]) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_submit_model_trainer(patch_submit_service_job): + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.SAGEMAKER_TRAINING_JOB + payload = { + "TrainingJobName": JOB_NAME, + "ResourceConfig": { + "InstanceType": "ml.m5.xlarge", + "InstanceCount": 1, + "VolumeSizeInGB": 30, + }, + } + trainer._create_training_job_args.return_value = payload + + patch_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + queue = TrainingQueue(JOB_QUEUE) + queue_job = queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + patch_submit_service_job.assert_called_once_with( + payload, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + TIMEOUT_CONFIG, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + assert queue_job.job_name == JOB_NAME + assert queue_job.job_arn == JOB_ARN + + +def test_submit_model_trainer_fail(): + trainer = Mock(spec=ModelTrainer) + trainer.training_mode = Mode.LOCAL_CONTAINER + + with pytest.raises( + ValueError, + match="TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB", + ): + queue = TrainingQueue(JOB_QUEUE) + queue.submit( + trainer, + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + + +@patch("sagemaker.aws_batch.training_queue.submit_service_job") +def test_submit_pytorch_estimator(patched_submit_service_job): + training_job_cls = _TrainingJob + training_job_cls.get_train_args = Mock(return_value=TRAINING_JOB_PAYLOAD_IN_PASCALCASE) + + patched_submit_service_job.return_value = SUBMIT_SERVICE_JOB_RESP + + queue = TrainingQueue(JOB_QUEUE) + queue_job = queue.submit( + PyTorch(), + {}, + JOB_NAME, + 
DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + DEFAULT_TIMEOUT, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) + patched_submit_service_job.assert_called_once_with( + TRAINING_JOB_PAYLOAD_IN_PASCALCASE, + JOB_NAME, + JOB_QUEUE, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + DEFAULT_TIMEOUT, + SHARE_IDENTIFIER, + BATCH_TAGS, + ) + assert queue_job.job_name == JOB_NAME + assert queue_job.job_arn == JOB_ARN + + +def test_submit_with_invalid_training_job(): + with pytest.raises( + TypeError, + match="training_job must be an instance of EstimatorBase or ModelTrainer", + ): + queue = TrainingQueue(JOB_QUEUE) + queue.submit( + TrainingQueue("NotAnEstimatorOrModelTrainer"), + [], + JOB_NAME, + DEFAULT_SAGEMAKER_TRAINING_RETRY_CONFIG, + SCHEDULING_PRIORITY, + SHARE_IDENTIFIER, + TIMEOUT_CONFIG, + BATCH_TAGS, + EXPERIMENT_CONFIG_EMPTY, + ) diff --git a/tests/unit/sagemaker/aws_batch/test_training_queued_job.py b/tests/unit/sagemaker/aws_batch/test_training_queued_job.py new file mode 100644 index 0000000000..fe5231a01d --- /dev/null +++ b/tests/unit/sagemaker/aws_batch/test_training_queued_job.py @@ -0,0 +1,170 @@ +from __future__ import absolute_import + +import pytest +import time +from mock.mock import patch +from unittest.mock import Mock + +from sagemaker.aws_batch.exception import NoTrainingJob, MissingRequiredArgument +from sagemaker.aws_batch.training_queued_job import TrainingQueuedJob +from sagemaker.config import SAGEMAKER, TRAINING_JOB +from .constants import ( + JOB_ARN, + JOB_NAME, + REASON, + TRAINING_IMAGE, + JOB_STATUS_RUNNING, + JOB_STATUS_RUNNABLE, + JOB_STATUS_FAILED, + JOB_STATUS_COMPLETED, + EXECUTION_ROLE, + TRAINING_JOB_ARN, +) +from tests.unit import SAGEMAKER_CONFIG_TRAINING_JOB + + +@patch("sagemaker.aws_batch.training_queued_job.terminate_service_job") +def test_queued_job_terminate(patched_terminate_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + queued_job.terminate(REASON) + patched_terminate_service_job.assert_called_once_with(queued_job.job_arn, REASON) + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_describe(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + queued_job.describe() + patched_describe_service_job.assert_called_once_with(queued_job.job_arn) + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_estimator_no_training_job_created(patched_describe_service_job): + patched_describe_service_job.return_value = {"status": JOB_STATUS_RUNNABLE} + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + with pytest.raises(NoTrainingJob): + queued_job.get_estimator() + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_estimator_missing_required_argument(patched_describe_service_job): + patched_describe_service_job.return_value = {"status": JOB_STATUS_RUNNING} + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + with pytest.raises(MissingRequiredArgument): + queued_job.get_estimator() + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +@patch("sagemaker.aws_batch.training_queued_job._construct_estimator_from_training_job_name") +def test_queued_job_estimator_happy_case( + patched_construct_estimator_from_training_job_name, patched_describe_service_job +): + training_job_config = SAGEMAKER_CONFIG_TRAINING_JOB[SAGEMAKER][TRAINING_JOB] + training_job_config["image_uri"] = TRAINING_IMAGE + 
training_job_config["job_name"] = JOB_NAME + training_job_config["role"] = EXECUTION_ROLE + describe_resp = { + "status": JOB_STATUS_RUNNING, + "latestAttempt": { + "serviceResourceId": {"name": "trainingJobArn", "value": TRAINING_JOB_ARN} + }, + } + patched_describe_service_job.return_value = describe_resp + + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + queued_job.get_estimator() + patched_construct_estimator_from_training_job_name.assert_called_once_with(JOB_NAME) + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_wait_no_timeout(patched_describe_service_job): + patched_describe_service_job.return_value = {"status": JOB_STATUS_COMPLETED} + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + result = queued_job.wait() + assert result.get("status", "") == JOB_STATUS_COMPLETED + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_wait_with_timeout_succeeds(patched_describe_service_job): + patched_describe_service_job.side_effect = [ + {"status": JOB_STATUS_RUNNING}, + {"status": JOB_STATUS_RUNNING}, + {"status": JOB_STATUS_COMPLETED}, + ] + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + start_time = time.time() + result = queued_job.wait(timeout=15) + end_time = time.time() + + assert end_time - start_time < 15 + assert result.get("status", "") == JOB_STATUS_COMPLETED + assert patched_describe_service_job.call_count == 3 + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queued_job_wait_with_timeout_times_out(patched_describe_service_job): + patched_describe_service_job.return_value = {"status": JOB_STATUS_RUNNING} + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + start_time = time.time() + result = queued_job.wait(timeout=5) + end_time = time.time() + + assert end_time - start_time > 5 + assert result.get("status", "") == JOB_STATUS_RUNNING + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +@pytest.mark.asyncio +async def test_queued_job_async_fetch_job_results_happy_case(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + queued_job.wait = Mock() + # queued_job.describe.return_value = {"status": JOB_STATUS_COMPLETED} + patched_describe_service_job.return_value = {"status": JOB_STATUS_COMPLETED} + + result = await queued_job.fetch_job_results() + assert result == {"status": JOB_STATUS_COMPLETED} + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +@pytest.mark.asyncio +async def test_queued_job_async_fetch_job_results_job_failed(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + queued_job.wait = Mock() + patched_describe_service_job.return_value = { + "status": JOB_STATUS_FAILED, + "statusReason": "Job failed", + } + + with pytest.raises(RuntimeError): + await queued_job.fetch_job_results() + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +@pytest.mark.asyncio +async def test_queued_job_async_fetch_job_results_timeout(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + + queued_job.wait = Mock() + patched_describe_service_job.return_value = {"status": JOB_STATUS_RUNNING} + + with pytest.raises(TimeoutError): + await queued_job.fetch_job_results(timeout=1) + + +@patch("sagemaker.aws_batch.training_queued_job.describe_service_job") +def test_queue_result_happy_case(patched_describe_service_job): + queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME) + 
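+    # describe_service_job is mocked to a terminal status, so result() should return without timing out.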
+    patched_describe_service_job.return_value = {"status": JOB_STATUS_COMPLETED}
+
+    result = queued_job.result(100)
+    assert result == {"status": JOB_STATUS_COMPLETED}
+
+
+@patch("sagemaker.aws_batch.training_queued_job.describe_service_job")
+def test_queue_result_job_times_out(patched_describe_service_job):
+    queued_job = TrainingQueuedJob(JOB_ARN, JOB_NAME)
+    patched_describe_service_job.return_value = {"status": JOB_STATUS_RUNNING}
+
+    with pytest.raises(TimeoutError):
+        queued_job.result(1)
diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py
index 4c93e18939..5d32030580 100644
--- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py
+++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py
@@ -75,12 +75,12 @@ def test_constructor_node_should_be_modified(src, expected):
         ("sagemaker.predictor._NumpyDeserializer()", "deserializers.NumpyDeserializer()"),
         ("sagemaker.predictor._JsonDeserializer()", "deserializers.JSONDeserializer()"),
         (
-            "sagemaker.amazon.common.numpy_to_record_serializer()",
-            "sagemaker.amazon.common.RecordSerializer()",
+            "sagemaker.serializers.numpy_to_record_serializer()",
+            "sagemaker.serializers.RecordSerializer()",
         ),
         (
-            "sagemaker.amazon.common.record_deserializer()",
-            "sagemaker.amazon.common.RecordDeserializer()",
+            "sagemaker.deserializers.record_deserializer()",
+            "sagemaker.deserializers.RecordDeserializer()",
         ),
         ("_CsvSerializer()", "serializers.CSVSerializer()"),
         ("_JsonSerializer()", "serializers.JSONSerializer()"),
@@ -265,20 +265,12 @@ def test_import_from_amazon_common_node_should_be_modified(import_statement, exp
     "import_statement, expected",
     [
         (
-            "from sagemaker.amazon.common import numpy_to_record_serializer",
-            "from sagemaker.amazon.common import RecordSerializer",
+            "from sagemaker.serializers import numpy_to_record_serializer",
+            "from sagemaker.serializers import RecordSerializer",
         ),
         (
-            "from sagemaker.amazon.common import record_deserializer",
-            "from sagemaker.amazon.common import RecordDeserializer",
-        ),
-        (
-            "from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer",
-            "from sagemaker.amazon.common import RecordSerializer, RecordDeserializer",
-        ),
-        (
-            "from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor, numpy_to_record_serializer",
-            "from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor, RecordSerializer",
+            "from sagemaker.deserializers import record_deserializer",
+            "from sagemaker.deserializers import RecordDeserializer",
         ),
     ],
 )
diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py
index d116c8121b..dda1e30db2 100644
--- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py
+++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py
@@ -56,6 +56,8 @@ def test_jumpstart_default_content_types(
         version=model_version,
         s3_client=mock_client,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
 
 
@@ -75,15 +77,12 @@ def test_jumpstart_supported_content_types(
     model_id, model_version = "predictor-specs-model", "*"
     region = "us-west-2"
 
-    supported_content_types = content_types.retrieve_options(
+    content_types.retrieve_options(
         region=region,
         model_id=model_id,
         model_version=model_version,
         sagemaker_session=mock_session,
     )
-    assert supported_content_types == [
-        "application/x-text",
-    ]
 
     patched_get_model_specs.assert_called_once_with(
         region=region,
@@ -91,4 +90,6 @@ def test_jumpstart_supported_content_types(
         version=model_version,
         s3_client=mock_client,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py
index f0102068e7..9bbca51654 100644
--- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py
+++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py
@@ -58,6 +58,8 @@ def test_jumpstart_default_deserializers(
         version=model_version,
         s3_client=mock_client,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
 
 
@@ -98,4 +100,6 @@ def test_jumpstart_deserializer_options(
         version=model_version,
         s3_client=mock_client,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py
index 5f00f93abf..13f720870c 100644
--- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py
+++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py
@@ -61,6 +61,8 @@ def test_jumpstart_default_environment_variables(
         version="*",
         s3_client=mock_client,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
 
     patched_get_model_specs.reset_mock()
@@ -85,6 +87,8 @@ def test_jumpstart_default_environment_variables(
         version="1.*",
         s3_client=mock_client,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
 
     patched_get_model_specs.reset_mock()
@@ -147,6 +151,8 @@ def test_jumpstart_sdk_environment_variables(
         version="*",
         s3_client=mock_client,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
 
     patched_get_model_specs.reset_mock()
@@ -172,6 +178,8 @@ def test_jumpstart_sdk_environment_variables(
         version="1.*",
         s3_client=mock_client,
         model_type=JumpStartModelType.OPEN_WEIGHTS,
+        hub_arn=None,
+        sagemaker_session=mock_session,
     )
 
     patched_get_model_specs.reset_mock()
diff --git a/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py b/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py
index 118800dd0f..f149823b2f 100644
--- a/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py
+++ b/tests/unit/sagemaker/feature_store/feature_processor/lineage/test_feature_processor_lineage.py
@@ -113,69 +113,85 @@ def test_create_lineage_when_no_lineage_exists_with_fg_only():
         transformation_code=TRANSFORMATION_CODE_INPUT_1,
         sagemaker_session=SAGEMAKER_SESSION_MOCK,
     )
-    with patch.object(
-        FeatureGroupLineageEntityHandler,
-        "retrieve_feature_group_context_arns",
-        side_effect=[
-            FEATURE_GROUP_INPUT[0],
-            FEATURE_GROUP_INPUT[1],
-            FEATURE_GROUP_INPUT[0],
-        ],
-    ) as retrieve_feature_group_context_arns_method, patch.object(
-        S3LineageEntityHandler,
-        "retrieve_raw_data_artifact",
-        side_effect=[
-            RAW_DATA_INPUT_ARTIFACTS[0],
-            RAW_DATA_INPUT_ARTIFACTS[1],
-            RAW_DATA_INPUT_ARTIFACTS[2],
-            RAW_DATA_INPUT_ARTIFACTS[3],
-        ],
-    ) as retrieve_raw_data_artifact_method, patch.object(
-        S3LineageEntityHandler,
-        "create_transformation_code_artifact",
-        return_value=TRANSFORMATION_CODE_ARTIFACT_1,
-    ) as create_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "load_pipeline_context",
-        side_effect=RESOURCE_NOT_FOUND_EXCEPTION,
-    ) as load_pipeline_context_method, patch.object(
-        PipelineLineageEntityHandler,
-        "create_pipeline_context",
-        return_value=PIPELINE_CONTEXT,
-    ), patch.object(
-        PipelineVersionLineageEntityHandler,
-        "create_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ), patch.object(
-        PipelineVersionLineageEntityHandler,
-        "load_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ) as load_pipeline_version_context_method, patch.object(
-        LineageAssociationHandler,
-        "list_upstream_associations",
-        side_effect=[
-            generate_pipeline_version_upstream_feature_group_list(),
-            [],
-            generate_pipeline_version_upstream_transformation_code(),
-        ],
-    ) as list_upstream_associations_method, patch.object(
-        LineageAssociationHandler,
-        "list_downstream_associations",
-        return_value=generate_pipeline_version_downstream_feature_group(),
-    ) as list_downstream_associations_method, patch.object(
-        PipelineLineageEntityHandler,
-        "update_pipeline_context",
-    ) as update_pipeline_context_method, patch.object(
-        LineageAssociationHandler, "add_upstream_feature_group_data_associations"
-    ) as add_upstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_downstream_feature_group_data_associations"
-    ) as add_downstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_raw_data_associations"
-    ) as add_upstream_raw_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_transformation_code_associations"
-    ) as add_upstream_transformation_code_associations_method, patch.object(
-        LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
-    ) as add_pipeline_and_pipeline_version_association_method:
+    with (
+        patch.object(
+            FeatureGroupLineageEntityHandler,
+            "retrieve_feature_group_context_arns",
+            side_effect=[
+                FEATURE_GROUP_INPUT[0],
+                FEATURE_GROUP_INPUT[1],
+                FEATURE_GROUP_INPUT[0],
+            ],
+        ) as retrieve_feature_group_context_arns_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "retrieve_raw_data_artifact",
+            side_effect=[
+                RAW_DATA_INPUT_ARTIFACTS[0],
+                RAW_DATA_INPUT_ARTIFACTS[1],
+                RAW_DATA_INPUT_ARTIFACTS[2],
+                RAW_DATA_INPUT_ARTIFACTS[3],
+            ],
+        ) as retrieve_raw_data_artifact_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "create_transformation_code_artifact",
+            return_value=TRANSFORMATION_CODE_ARTIFACT_1,
+        ) as create_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "load_pipeline_context",
+            side_effect=RESOURCE_NOT_FOUND_EXCEPTION,
+        ) as load_pipeline_context_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "create_pipeline_context",
+            return_value=PIPELINE_CONTEXT,
+        ),
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "create_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ),
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "load_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ) as load_pipeline_version_context_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_upstream_associations",
+            side_effect=[
+                generate_pipeline_version_upstream_feature_group_list(),
+                [],
+                generate_pipeline_version_upstream_transformation_code(),
+            ],
+        ) as list_upstream_associations_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_downstream_associations",
+            return_value=generate_pipeline_version_downstream_feature_group(),
+        ) as list_downstream_associations_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "update_pipeline_context",
+        ) as update_pipeline_context_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_feature_group_data_associations"
+        ) as add_upstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_downstream_feature_group_data_associations"
+        ) as add_downstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_raw_data_associations"
+        ) as add_upstream_raw_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_transformation_code_associations"
+        ) as add_upstream_transformation_code_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
+        ) as add_pipeline_and_pipeline_version_association_method,
+    ):
         lineage_handler.create_lineage()
 
     retrieve_feature_group_context_arns_method.assert_has_calls(
@@ -259,75 +275,92 @@ def test_create_lineage_when_no_lineage_exists_with_raw_data_only():
         transformation_code=TRANSFORMATION_CODE_INPUT_1,
         sagemaker_session=SAGEMAKER_SESSION_MOCK,
     )
-    with patch.object(
-        FeatureGroupLineageEntityHandler,
-        "retrieve_feature_group_context_arns",
-        side_effect=[
-            FEATURE_GROUP_INPUT[0],
-            FEATURE_GROUP_INPUT[1],
-            FEATURE_GROUP_INPUT[0],
-        ],
-    ) as retrieve_feature_group_context_arns_method, patch.object(
-        S3LineageEntityHandler,
-        "retrieve_raw_data_artifact",
-        side_effect=[
-            RAW_DATA_INPUT_ARTIFACTS[0],
-            RAW_DATA_INPUT_ARTIFACTS[1],
-            RAW_DATA_INPUT_ARTIFACTS[2],
-            RAW_DATA_INPUT_ARTIFACTS[3],
-        ],
-    ) as retrieve_raw_data_artifact_method, patch.object(
-        S3LineageEntityHandler,
-        "create_transformation_code_artifact",
-        return_value=TRANSFORMATION_CODE_ARTIFACT_1,
-    ) as create_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "load_pipeline_context",
-        side_effect=RESOURCE_NOT_FOUND_EXCEPTION,
-    ) as load_pipeline_context_method, patch.object(
-        PipelineLineageEntityHandler,
-        "create_pipeline_context",
-        return_value=PIPELINE_CONTEXT,
-    ), patch.object(
-        PipelineVersionLineageEntityHandler,
-        "create_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ), patch.object(
-        PipelineVersionLineageEntityHandler,
-        "load_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ) as load_pipeline_version_context_method, patch.object(
-        LineageAssociationHandler,
-        "list_upstream_associations",
-        side_effect=[
-            generate_pipeline_version_upstream_feature_group_list(),
-            [],
-            generate_pipeline_version_upstream_transformation_code(),
-        ],
-    ) as list_upstream_associations_method, patch.object(
-        LineageAssociationHandler,
-        "list_downstream_associations",
-        return_value=generate_pipeline_version_downstream_feature_group(),
-    ) as list_downstream_associations_method, patch.object(
-        PipelineLineageEntityHandler,
-        "update_pipeline_context",
-    ) as update_pipeline_context_method, patch.object(
-        LineageAssociationHandler, "add_upstream_feature_group_data_associations"
-    ) as add_upstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_downstream_feature_group_data_associations"
-    ) as add_downstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_raw_data_associations"
-    ) as add_upstream_raw_data_associations_method, patch.object(
LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_called_once_with( @@ -408,75 +441,92 @@ def test_create_lineage_when_no_lineage_exists_with_fg_and_raw_data_with_tags(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with 
-    with patch.object(
-        FeatureGroupLineageEntityHandler,
-        "retrieve_feature_group_context_arns",
-        side_effect=[
-            FEATURE_GROUP_INPUT[0],
-            FEATURE_GROUP_INPUT[1],
-            FEATURE_GROUP_INPUT[0],
-        ],
-    ) as retrieve_feature_group_context_arns_method, patch.object(
-        S3LineageEntityHandler,
-        "retrieve_raw_data_artifact",
-        side_effect=[
-            RAW_DATA_INPUT_ARTIFACTS[0],
-            RAW_DATA_INPUT_ARTIFACTS[1],
-            RAW_DATA_INPUT_ARTIFACTS[2],
-            RAW_DATA_INPUT_ARTIFACTS[3],
-        ],
-    ) as retrieve_raw_data_artifact_method, patch.object(
-        S3LineageEntityHandler,
-        "create_transformation_code_artifact",
-        return_value=TRANSFORMATION_CODE_ARTIFACT_1,
-    ) as create_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "load_pipeline_context",
-        side_effect=RESOURCE_NOT_FOUND_EXCEPTION,
-    ) as load_pipeline_context_method, patch.object(
-        PipelineLineageEntityHandler,
-        "create_pipeline_context",
-        return_value=PIPELINE_CONTEXT,
-    ), patch.object(
-        PipelineVersionLineageEntityHandler,
-        "create_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ), patch.object(
-        PipelineVersionLineageEntityHandler,
-        "load_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ) as load_pipeline_version_context_method, patch.object(
-        LineageAssociationHandler,
-        "list_upstream_associations",
-        side_effect=[
-            generate_pipeline_version_upstream_feature_group_list(),
-            [],
-            generate_pipeline_version_upstream_transformation_code(),
-        ],
-    ) as list_upstream_associations_method, patch.object(
-        LineageAssociationHandler,
-        "list_downstream_associations",
-        return_value=generate_pipeline_version_downstream_feature_group(),
-    ) as list_downstream_associations_method, patch.object(
-        PipelineLineageEntityHandler,
-        "update_pipeline_context",
-    ) as update_pipeline_context_method, patch.object(
-        LineageAssociationHandler, "add_upstream_feature_group_data_associations"
-    ) as add_upstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_downstream_feature_group_data_associations"
-    ) as add_downstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_raw_data_associations"
-    ) as add_upstream_raw_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_transformation_code_associations"
-    ) as add_upstream_transformation_code_associations_method, patch.object(
-        LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
-    ) as add_pipeline_and_pipeline_version_association_method, patch.object(
-        Artifact,
-        "set_tags",
-        return_value={
-            "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
-        },
-    ) as artifact_set_tags:
+    with (
+        patch.object(
+            FeatureGroupLineageEntityHandler,
+            "retrieve_feature_group_context_arns",
+            side_effect=[
+                FEATURE_GROUP_INPUT[0],
+                FEATURE_GROUP_INPUT[1],
+                FEATURE_GROUP_INPUT[0],
+            ],
+        ) as retrieve_feature_group_context_arns_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "retrieve_raw_data_artifact",
+            side_effect=[
+                RAW_DATA_INPUT_ARTIFACTS[0],
+                RAW_DATA_INPUT_ARTIFACTS[1],
+                RAW_DATA_INPUT_ARTIFACTS[2],
+                RAW_DATA_INPUT_ARTIFACTS[3],
+            ],
+        ) as retrieve_raw_data_artifact_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "create_transformation_code_artifact",
+            return_value=TRANSFORMATION_CODE_ARTIFACT_1,
+        ) as create_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "load_pipeline_context",
+            side_effect=RESOURCE_NOT_FOUND_EXCEPTION,
+        ) as load_pipeline_context_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "create_pipeline_context",
+            return_value=PIPELINE_CONTEXT,
+        ),
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "create_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ),
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "load_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ) as load_pipeline_version_context_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_upstream_associations",
+            side_effect=[
+                generate_pipeline_version_upstream_feature_group_list(),
+                [],
+                generate_pipeline_version_upstream_transformation_code(),
+            ],
+        ) as list_upstream_associations_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_downstream_associations",
+            return_value=generate_pipeline_version_downstream_feature_group(),
+        ) as list_downstream_associations_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "update_pipeline_context",
+        ) as update_pipeline_context_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_feature_group_data_associations"
+        ) as add_upstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_downstream_feature_group_data_associations"
+        ) as add_downstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_raw_data_associations"
+        ) as add_upstream_raw_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_transformation_code_associations"
+        ) as add_upstream_transformation_code_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
+        ) as add_pipeline_and_pipeline_version_association_method,
+        patch.object(
+            Artifact,
+            "set_tags",
+            return_value={
+                "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
+            },
+        ) as artifact_set_tags,
+    ):
         lineage_handler.create_lineage(TAGS)
 
     retrieve_feature_group_context_arns_method.assert_has_calls(
@@ -569,75 +619,92 @@ def test_create_lineage_when_no_lineage_exists_with_no_transformation_code():
         output=FEATURE_GROUP_DATA_SOURCE[0].name,
         sagemaker_session=SAGEMAKER_SESSION_MOCK,
     )
-    with patch.object(
-        FeatureGroupLineageEntityHandler,
-        "retrieve_feature_group_context_arns",
-        side_effect=[
-            FEATURE_GROUP_INPUT[0],
-            FEATURE_GROUP_INPUT[1],
-            FEATURE_GROUP_INPUT[0],
-        ],
-    ) as retrieve_feature_group_context_arns_method, patch.object(
-        S3LineageEntityHandler,
-        "retrieve_raw_data_artifact",
-        side_effect=[
-            RAW_DATA_INPUT_ARTIFACTS[0],
-            RAW_DATA_INPUT_ARTIFACTS[1],
-            RAW_DATA_INPUT_ARTIFACTS[2],
-            RAW_DATA_INPUT_ARTIFACTS[3],
-        ],
-    ) as retrieve_raw_data_artifact_method, patch.object(
-        S3LineageEntityHandler,
-        "create_transformation_code_artifact",
-        return_value=None,
-    ) as create_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "load_pipeline_context",
-        side_effect=RESOURCE_NOT_FOUND_EXCEPTION,
-    ) as load_pipeline_context_method, patch.object(
-        PipelineLineageEntityHandler,
-        "create_pipeline_context",
-        return_value=PIPELINE_CONTEXT,
-    ), patch.object(
-        PipelineVersionLineageEntityHandler,
-        "create_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ), patch.object(
-        PipelineVersionLineageEntityHandler,
-        "load_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ) as load_pipeline_version_context_method, patch.object(
-        LineageAssociationHandler,
-        "list_upstream_associations",
-        side_effect=[
-            generate_pipeline_version_upstream_feature_group_list(),
-            [],
-            generate_pipeline_version_upstream_transformation_code(),
-        ],
-    ) as list_upstream_associations_method, patch.object(
-        LineageAssociationHandler,
-        "list_downstream_associations",
-        return_value=generate_pipeline_version_downstream_feature_group(),
-    ) as list_downstream_associations_method, patch.object(
-        PipelineLineageEntityHandler,
-        "update_pipeline_context",
-    ) as update_pipeline_context_method, patch.object(
-        LineageAssociationHandler, "add_upstream_feature_group_data_associations"
-    ) as add_upstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_downstream_feature_group_data_associations"
-    ) as add_downstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_raw_data_associations"
-    ) as add_upstream_raw_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_transformation_code_associations"
-    ) as add_upstream_transformation_code_associations_method, patch.object(
-        LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
-    ) as add_pipeline_and_pipeline_version_association_method, patch.object(
-        Artifact,
-        "set_tags",
-        return_value={
-            "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
-        },
-    ) as artifact_set_tags:
+    with (
+        patch.object(
+            FeatureGroupLineageEntityHandler,
+            "retrieve_feature_group_context_arns",
+            side_effect=[
+                FEATURE_GROUP_INPUT[0],
+                FEATURE_GROUP_INPUT[1],
+                FEATURE_GROUP_INPUT[0],
+            ],
+        ) as retrieve_feature_group_context_arns_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "retrieve_raw_data_artifact",
+            side_effect=[
+                RAW_DATA_INPUT_ARTIFACTS[0],
+                RAW_DATA_INPUT_ARTIFACTS[1],
+                RAW_DATA_INPUT_ARTIFACTS[2],
+                RAW_DATA_INPUT_ARTIFACTS[3],
+            ],
+        ) as retrieve_raw_data_artifact_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "create_transformation_code_artifact",
+            return_value=None,
+        ) as create_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "load_pipeline_context",
+            side_effect=RESOURCE_NOT_FOUND_EXCEPTION,
+        ) as load_pipeline_context_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "create_pipeline_context",
+            return_value=PIPELINE_CONTEXT,
+        ),
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "create_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ),
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "load_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ) as load_pipeline_version_context_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_upstream_associations",
+            side_effect=[
+                generate_pipeline_version_upstream_feature_group_list(),
+                [],
+                generate_pipeline_version_upstream_transformation_code(),
+            ],
+        ) as list_upstream_associations_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_downstream_associations",
+            return_value=generate_pipeline_version_downstream_feature_group(),
+        ) as list_downstream_associations_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "update_pipeline_context",
+        ) as update_pipeline_context_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_feature_group_data_associations"
+        ) as add_upstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_downstream_feature_group_data_associations"
+        ) as add_downstream_feature_group_data_associations_method,
+        patch.object(
LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -728,78 +795,96 @@ def test_create_lineage_when_already_exist_with_no_version_change(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_1, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=PIPELINE_CONTEXT, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - return_value=transformation_code_1, - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as create_pipeline_version_context_method, patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( 
- LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as create_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -925,73 +1010,91 @@ def test_create_lineage_when_already_exist_with_changed_raw_data(): 
         transformation_code=TRANSFORMATION_CODE_INPUT_1,
         sagemaker_session=SAGEMAKER_SESSION_MOCK,
     )
-    with patch.object(
-        FeatureGroupLineageEntityHandler,
-        "retrieve_feature_group_context_arns",
-        side_effect=[
-            FEATURE_GROUP_INPUT[0],
-            FEATURE_GROUP_INPUT[1],
-            FEATURE_GROUP_INPUT[0],
-        ],
-    ) as retrieve_feature_group_context_arns_method, patch.object(
-        S3LineageEntityHandler,
-        "retrieve_raw_data_artifact",
-        side_effect=[RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1]],
-    ) as retrieve_raw_data_artifact_method, patch.object(
-        S3LineageEntityHandler,
-        "create_transformation_code_artifact",
-        return_value=TRANSFORMATION_CODE_ARTIFACT_1,
-    ) as create_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "load_pipeline_context",
-        return_value=pipeline_context,
-    ) as load_pipeline_context_method, patch.object(
-        PipelineVersionLineageEntityHandler,
-        "load_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ) as load_pipeline_version_context_method, patch.object(
-        LineageAssociationHandler,
-        "list_upstream_associations",
-        side_effect=[
-            generate_pipeline_version_upstream_feature_group_list(),
-            generate_pipeline_version_upstream_raw_data_list(),
-            generate_pipeline_version_upstream_transformation_code(),
-        ],
-    ) as list_upstream_associations_method, patch.object(
-        LineageAssociationHandler,
-        "list_downstream_associations",
-        return_value=generate_pipeline_version_downstream_feature_group(),
-    ) as list_downstream_associations_method, patch.object(
-        S3LineageEntityHandler,
-        "load_artifact_from_arn",
-        return_value=transformation_code_1,
-    ) as load_artifact_from_arn_method, patch.object(
-        S3LineageEntityHandler,
-        "update_transformation_code_artifact",
-    ) as update_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "update_pipeline_context",
-    ) as update_pipeline_context_method, patch.object(
-        PipelineVersionLineageEntityHandler,
-        "create_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ), patch.object(
-        LineageAssociationHandler, "add_upstream_feature_group_data_associations"
-    ) as add_upstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_downstream_feature_group_data_associations"
-    ) as add_downstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_raw_data_associations"
-    ) as add_upstream_raw_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_transformation_code_associations"
-    ) as add_upstream_transformation_code_associations_method, patch.object(
-        LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
-    ) as add_pipeline_and_pipeline_version_association_method, patch.object(
-        Artifact,
-        "set_tags",
-        return_value={
-            "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
-        },
-    ) as artifact_set_tags:
+    with (
+        patch.object(
+            FeatureGroupLineageEntityHandler,
+            "retrieve_feature_group_context_arns",
+            side_effect=[
+                FEATURE_GROUP_INPUT[0],
+                FEATURE_GROUP_INPUT[1],
+                FEATURE_GROUP_INPUT[0],
+            ],
+        ) as retrieve_feature_group_context_arns_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "retrieve_raw_data_artifact",
+            side_effect=[RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1]],
+        ) as retrieve_raw_data_artifact_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "create_transformation_code_artifact",
+            return_value=TRANSFORMATION_CODE_ARTIFACT_1,
+        ) as create_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "load_pipeline_context",
+            return_value=pipeline_context,
+        ) as load_pipeline_context_method,
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "load_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ) as load_pipeline_version_context_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_upstream_associations",
+            side_effect=[
+                generate_pipeline_version_upstream_feature_group_list(),
+                generate_pipeline_version_upstream_raw_data_list(),
+                generate_pipeline_version_upstream_transformation_code(),
+            ],
+        ) as list_upstream_associations_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_downstream_associations",
+            return_value=generate_pipeline_version_downstream_feature_group(),
+        ) as list_downstream_associations_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "load_artifact_from_arn",
+            return_value=transformation_code_1,
+        ) as load_artifact_from_arn_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "update_transformation_code_artifact",
+        ) as update_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "update_pipeline_context",
+        ) as update_pipeline_context_method,
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "create_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ),
+        patch.object(
+            LineageAssociationHandler, "add_upstream_feature_group_data_associations"
+        ) as add_upstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_downstream_feature_group_data_associations"
+        ) as add_downstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_raw_data_associations"
+        ) as add_upstream_raw_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_transformation_code_associations"
+        ) as add_upstream_transformation_code_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
+        ) as add_pipeline_and_pipeline_version_association_method,
+        patch.object(
+            Artifact,
+            "set_tags",
+            return_value={
+                "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
+            },
+        ) as artifact_set_tags,
+    ):
         lineage_handler.create_lineage(TAGS)
 
     retrieve_feature_group_context_arns_method.assert_has_calls(
@@ -1140,74 +1243,92 @@ def test_create_lineage_when_already_exist_with_changed_input_fg():
         transformation_code=TRANSFORMATION_CODE_INPUT_1,
         sagemaker_session=SAGEMAKER_SESSION_MOCK,
     )
-    with patch.object(
-        FeatureGroupLineageEntityHandler,
-        "retrieve_feature_group_context_arns",
-        side_effect=[FEATURE_GROUP_INPUT[0], FEATURE_GROUP_INPUT[0]],
-    ) as retrieve_feature_group_context_arns_method, patch.object(
-        S3LineageEntityHandler,
-        "retrieve_raw_data_artifact",
-        side_effect=[
-            RAW_DATA_INPUT_ARTIFACTS[0],
-            RAW_DATA_INPUT_ARTIFACTS[1],
-            RAW_DATA_INPUT_ARTIFACTS[2],
-            RAW_DATA_INPUT_ARTIFACTS[3],
-        ],
-    ) as retrieve_raw_data_artifact_method, patch.object(
-        S3LineageEntityHandler,
-        "create_transformation_code_artifact",
-        return_value=TRANSFORMATION_CODE_ARTIFACT_1,
-    ) as create_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "load_pipeline_context",
-        return_value=pipeline_context,
-    ) as load_pipeline_context_method, patch.object(
-        PipelineVersionLineageEntityHandler,
-        "load_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ) as load_pipeline_version_context_method, patch.object(
-        LineageAssociationHandler,
-        "list_upstream_associations",
-        side_effect=[
-            generate_pipeline_version_upstream_feature_group_list(),
-            generate_pipeline_version_upstream_raw_data_list(),
-            generate_pipeline_version_upstream_transformation_code(),
-        ],
-    ) as list_upstream_associations_method, patch.object(
-        LineageAssociationHandler,
-        "list_downstream_associations",
-        return_value=generate_pipeline_version_downstream_feature_group(),
-    ) as list_downstream_associations_method, patch.object(
-        S3LineageEntityHandler,
-        "load_artifact_from_arn",
-        return_value=transformation_code_1,
-    ) as load_artifact_from_arn_method, patch.object(
-        S3LineageEntityHandler,
-        "update_transformation_code_artifact",
-    ) as update_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "update_pipeline_context",
-    ) as update_pipeline_context_method, patch.object(
-        PipelineVersionLineageEntityHandler,
-        "create_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ), patch.object(
-        LineageAssociationHandler, "add_upstream_feature_group_data_associations"
-    ) as add_upstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_downstream_feature_group_data_associations"
-    ) as add_downstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_raw_data_associations"
-    ) as add_upstream_raw_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_transformation_code_associations"
-    ) as add_upstream_transformation_code_associations_method, patch.object(
-        LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
-    ) as add_pipeline_and_pipeline_version_association_method, patch.object(
-        Artifact,
-        "set_tags",
-        return_value={
-            "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
-        },
-    ) as artifact_set_tags:
+    with (
+        patch.object(
+            FeatureGroupLineageEntityHandler,
+            "retrieve_feature_group_context_arns",
+            side_effect=[FEATURE_GROUP_INPUT[0], FEATURE_GROUP_INPUT[0]],
+        ) as retrieve_feature_group_context_arns_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "retrieve_raw_data_artifact",
+            side_effect=[
+                RAW_DATA_INPUT_ARTIFACTS[0],
+                RAW_DATA_INPUT_ARTIFACTS[1],
+                RAW_DATA_INPUT_ARTIFACTS[2],
+                RAW_DATA_INPUT_ARTIFACTS[3],
+            ],
+        ) as retrieve_raw_data_artifact_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "create_transformation_code_artifact",
+            return_value=TRANSFORMATION_CODE_ARTIFACT_1,
+        ) as create_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "load_pipeline_context",
+            return_value=pipeline_context,
+        ) as load_pipeline_context_method,
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "load_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ) as load_pipeline_version_context_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_upstream_associations",
+            side_effect=[
+                generate_pipeline_version_upstream_feature_group_list(),
+                generate_pipeline_version_upstream_raw_data_list(),
+                generate_pipeline_version_upstream_transformation_code(),
+            ],
+        ) as list_upstream_associations_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_downstream_associations",
+            return_value=generate_pipeline_version_downstream_feature_group(),
+        ) as list_downstream_associations_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "load_artifact_from_arn",
+            return_value=transformation_code_1,
+        ) as load_artifact_from_arn_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "update_transformation_code_artifact",
+        ) as update_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "update_pipeline_context",
+        ) as update_pipeline_context_method,
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "create_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ),
+        patch.object(
+            LineageAssociationHandler, "add_upstream_feature_group_data_associations"
+        ) as add_upstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_downstream_feature_group_data_associations"
+        ) as add_downstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_raw_data_associations"
+        ) as add_upstream_raw_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_transformation_code_associations"
+        ) as add_upstream_transformation_code_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
+        ) as add_pipeline_and_pipeline_version_association_method,
+        patch.object(
+            Artifact,
+            "set_tags",
+            return_value={
+                "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
+            },
+        ) as artifact_set_tags,
+    ):
         lineage_handler.create_lineage(TAGS)
 
     retrieve_feature_group_context_arns_method.assert_has_calls(
@@ -1354,78 +1475,96 @@ def test_create_lineage_when_already_exist_with_changed_output_fg():
         transformation_code=TRANSFORMATION_CODE_INPUT_1,
         sagemaker_session=SAGEMAKER_SESSION_MOCK,
     )
-    with patch.object(
-        FeatureGroupLineageEntityHandler,
-        "retrieve_feature_group_context_arns",
-        side_effect=[
-            FEATURE_GROUP_INPUT[0],
-            FEATURE_GROUP_INPUT[1],
-            FEATURE_GROUP_INPUT[1],
-        ],
-    ) as retrieve_feature_group_context_arns_method, patch.object(
-        S3LineageEntityHandler,
-        "retrieve_raw_data_artifact",
-        side_effect=[
-            RAW_DATA_INPUT_ARTIFACTS[0],
-            RAW_DATA_INPUT_ARTIFACTS[1],
-            RAW_DATA_INPUT_ARTIFACTS[2],
-            RAW_DATA_INPUT_ARTIFACTS[3],
-        ],
-    ) as retrieve_raw_data_artifact_method, patch.object(
-        S3LineageEntityHandler,
-        "create_transformation_code_artifact",
-        return_value=TRANSFORMATION_CODE_ARTIFACT_1,
-    ) as create_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "load_pipeline_context",
-        return_value=pipeline_context,
-    ) as load_pipeline_context_method, patch.object(
-        PipelineVersionLineageEntityHandler,
-        "load_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ) as load_pipeline_version_context_method, patch.object(
-        LineageAssociationHandler,
-        "list_upstream_associations",
-        side_effect=[
-            generate_pipeline_version_upstream_feature_group_list(),
-            generate_pipeline_version_upstream_raw_data_list(),
-            generate_pipeline_version_upstream_transformation_code(),
-        ],
-    ) as list_upstream_associations_method, patch.object(
-        LineageAssociationHandler,
-        "list_downstream_associations",
-        return_value=generate_pipeline_version_downstream_feature_group(),
-    ) as list_downstream_associations_method, patch.object(
-        S3LineageEntityHandler,
-        "load_artifact_from_arn",
-        return_value=transformation_code_1,
-    ) as load_artifact_from_arn_method, patch.object(
-        S3LineageEntityHandler,
-        "update_transformation_code_artifact",
-    ) as update_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "update_pipeline_context",
-    ) as update_pipeline_context_method, patch.object(
-        PipelineVersionLineageEntityHandler,
-        "create_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ), patch.object(
-        LineageAssociationHandler, "add_upstream_feature_group_data_associations"
-    ) as add_upstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_downstream_feature_group_data_associations"
-    ) as add_downstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_raw_data_associations"
-    ) as add_upstream_raw_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_transformation_code_associations"
-    ) as add_upstream_transformation_code_associations_method, patch.object(
-        LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
-    ) as add_pipeline_and_pipeline_version_association_method, patch.object(
-        Artifact,
-        "set_tags",
-        return_value={
-            "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
-        },
-    ) as artifact_set_tags:
+    with (
+        patch.object(
+            FeatureGroupLineageEntityHandler,
+            "retrieve_feature_group_context_arns",
+            side_effect=[
+                FEATURE_GROUP_INPUT[0],
+                FEATURE_GROUP_INPUT[1],
+                FEATURE_GROUP_INPUT[1],
+            ],
+        ) as retrieve_feature_group_context_arns_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "retrieve_raw_data_artifact",
+            side_effect=[
+                RAW_DATA_INPUT_ARTIFACTS[0],
+                RAW_DATA_INPUT_ARTIFACTS[1],
+                RAW_DATA_INPUT_ARTIFACTS[2],
+                RAW_DATA_INPUT_ARTIFACTS[3],
+            ],
+        ) as retrieve_raw_data_artifact_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "create_transformation_code_artifact",
+            return_value=TRANSFORMATION_CODE_ARTIFACT_1,
+        ) as create_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "load_pipeline_context",
+            return_value=pipeline_context,
+        ) as load_pipeline_context_method,
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "load_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ) as load_pipeline_version_context_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_upstream_associations",
+            side_effect=[
+                generate_pipeline_version_upstream_feature_group_list(),
+                generate_pipeline_version_upstream_raw_data_list(),
+                generate_pipeline_version_upstream_transformation_code(),
+            ],
+        ) as list_upstream_associations_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_downstream_associations",
+            return_value=generate_pipeline_version_downstream_feature_group(),
+        ) as list_downstream_associations_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "load_artifact_from_arn",
+            return_value=transformation_code_1,
+        ) as load_artifact_from_arn_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "update_transformation_code_artifact",
+        ) as update_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "update_pipeline_context",
+        ) as update_pipeline_context_method,
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "create_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ),
+        patch.object(
+            LineageAssociationHandler, "add_upstream_feature_group_data_associations"
+        ) as add_upstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_downstream_feature_group_data_associations"
+        ) as add_downstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_raw_data_associations"
+        ) as add_upstream_raw_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_transformation_code_associations"
+        ) as add_upstream_transformation_code_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
+        ) as add_pipeline_and_pipeline_version_association_method,
+        patch.object(
+            Artifact,
+            "set_tags",
+            return_value={
+                "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
+            },
+        ) as artifact_set_tags,
+    ):
         lineage_handler.create_lineage(TAGS)
 
     retrieve_feature_group_context_arns_method.assert_has_calls(
@@ -1576,78 +1715,96 @@ def test_create_lineage_when_already_exist_with_changed_transformation_code():
         transformation_code=TRANSFORMATION_CODE_INPUT_2,
         sagemaker_session=SAGEMAKER_SESSION_MOCK,
     )
-    with patch.object(
-        FeatureGroupLineageEntityHandler,
-        "retrieve_feature_group_context_arns",
-        side_effect=[
-            FEATURE_GROUP_INPUT[0],
-            FEATURE_GROUP_INPUT[1],
-            FEATURE_GROUP_INPUT[0],
-        ],
-    ) as retrieve_feature_group_context_arns_method, patch.object(
-        S3LineageEntityHandler,
-        "retrieve_raw_data_artifact",
-        side_effect=[
-            RAW_DATA_INPUT_ARTIFACTS[0],
-            RAW_DATA_INPUT_ARTIFACTS[1],
-            RAW_DATA_INPUT_ARTIFACTS[2],
-            RAW_DATA_INPUT_ARTIFACTS[3],
-        ],
-    ) as retrieve_raw_data_artifact_method, patch.object(
-        S3LineageEntityHandler,
-        "create_transformation_code_artifact",
-        return_value=TRANSFORMATION_CODE_ARTIFACT_2,
-    ) as create_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "load_pipeline_context",
-        return_value=pipeline_context,
-    ) as load_pipeline_context_method, patch.object(
-        PipelineVersionLineageEntityHandler,
-        "load_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ) as load_pipeline_version_context_method, patch.object(
-        LineageAssociationHandler,
-        "list_upstream_associations",
-        side_effect=[
-            generate_pipeline_version_upstream_feature_group_list(),
-            generate_pipeline_version_upstream_raw_data_list(),
-            generate_pipeline_version_upstream_transformation_code(),
-        ],
-    ) as list_upstream_associations_method, patch.object(
-        LineageAssociationHandler,
-        "list_downstream_associations",
-        return_value=generate_pipeline_version_downstream_feature_group(),
-    ) as list_downstream_associations_method, patch.object(
-        S3LineageEntityHandler,
-        "load_artifact_from_arn",
-        return_value=transformation_code_1,
-    ) as load_artifact_from_arn_method, patch.object(
-        S3LineageEntityHandler,
-        "update_transformation_code_artifact",
-    ) as update_transformation_code_artifact_method, patch.object(
-        PipelineLineageEntityHandler,
-        "update_pipeline_context",
-    ) as update_pipeline_context_method, patch.object(
-        PipelineVersionLineageEntityHandler,
-        "create_pipeline_version_context",
-        return_value=PIPELINE_VERSION_CONTEXT,
-    ), patch.object(
-        LineageAssociationHandler, "add_upstream_feature_group_data_associations"
-    ) as add_upstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_downstream_feature_group_data_associations"
-    ) as add_downstream_feature_group_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_raw_data_associations"
-    ) as add_upstream_raw_data_associations_method, patch.object(
-        LineageAssociationHandler, "add_upstream_transformation_code_associations"
-    ) as add_upstream_transformation_code_associations_method, patch.object(
-        LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
-    ) as add_pipeline_and_pipeline_version_association_method, patch.object(
-        Artifact,
-        "set_tags",
-        return_value={
-            "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
-        },
-    ) as artifact_set_tags:
+    with (
+        patch.object(
+            FeatureGroupLineageEntityHandler,
+            "retrieve_feature_group_context_arns",
+            side_effect=[
+                FEATURE_GROUP_INPUT[0],
+                FEATURE_GROUP_INPUT[1],
+                FEATURE_GROUP_INPUT[0],
+            ],
+        ) as retrieve_feature_group_context_arns_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "retrieve_raw_data_artifact",
+            side_effect=[
+                RAW_DATA_INPUT_ARTIFACTS[0],
+                RAW_DATA_INPUT_ARTIFACTS[1],
+                RAW_DATA_INPUT_ARTIFACTS[2],
+                RAW_DATA_INPUT_ARTIFACTS[3],
+            ],
+        ) as retrieve_raw_data_artifact_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "create_transformation_code_artifact",
+            return_value=TRANSFORMATION_CODE_ARTIFACT_2,
+        ) as create_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "load_pipeline_context",
+            return_value=pipeline_context,
+        ) as load_pipeline_context_method,
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "load_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ) as load_pipeline_version_context_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_upstream_associations",
+            side_effect=[
+                generate_pipeline_version_upstream_feature_group_list(),
+                generate_pipeline_version_upstream_raw_data_list(),
+                generate_pipeline_version_upstream_transformation_code(),
+            ],
+        ) as list_upstream_associations_method,
+        patch.object(
+            LineageAssociationHandler,
+            "list_downstream_associations",
+            return_value=generate_pipeline_version_downstream_feature_group(),
+        ) as list_downstream_associations_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "load_artifact_from_arn",
+            return_value=transformation_code_1,
+        ) as load_artifact_from_arn_method,
+        patch.object(
+            S3LineageEntityHandler,
+            "update_transformation_code_artifact",
+        ) as update_transformation_code_artifact_method,
+        patch.object(
+            PipelineLineageEntityHandler,
+            "update_pipeline_context",
+        ) as update_pipeline_context_method,
+        patch.object(
+            PipelineVersionLineageEntityHandler,
+            "create_pipeline_version_context",
+            return_value=PIPELINE_VERSION_CONTEXT,
+        ),
+        patch.object(
+            LineageAssociationHandler, "add_upstream_feature_group_data_associations"
+        ) as add_upstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_downstream_feature_group_data_associations"
+        ) as add_downstream_feature_group_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_raw_data_associations"
+        ) as add_upstream_raw_data_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_upstream_transformation_code_associations"
+        ) as add_upstream_transformation_code_associations_method,
+        patch.object(
+            LineageAssociationHandler, "add_pipeline_and_pipeline_version_association"
+        ) as add_pipeline_and_pipeline_version_association_method,
+        patch.object(
+            Artifact,
+            "set_tags",
+            return_value={
+                "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")]
+            },
+        ) as artifact_set_tags,
+    ):
         lineage_handler.create_lineage(TAGS)
 
     retrieve_feature_group_context_arns_method.assert_has_calls(
@@ -1778,78 +1935,96 @@ def test_create_lineage_when_already_exist_with_last_transformation_code_as_none
         transformation_code=TRANSFORMATION_CODE_INPUT_2,
         sagemaker_session=SAGEMAKER_SESSION_MOCK,
     )
-    with patch.object(
FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_2, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - return_value=transformation_code_1, - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + 
return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -1968,77 +2143,95 @@ def test_create_lineage_when_already_exist_with_all_previous_transformation_code transformation_code=TRANSFORMATION_CODE_INPUT_2, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=TRANSFORMATION_CODE_ARTIFACT_2, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - 
PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - iter([]), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + iter([]), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + ) 
as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -2154,78 +2347,96 @@ def test_create_lineage_when_already_exist_with_removed_transformation_code(): output=FEATURE_GROUP_DATA_SOURCE[0].name, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - FeatureGroupLineageEntityHandler, - "retrieve_feature_group_context_arns", - side_effect=[ - FEATURE_GROUP_INPUT[0], - FEATURE_GROUP_INPUT[1], - FEATURE_GROUP_INPUT[0], - ], - ) as retrieve_feature_group_context_arns_method, patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - S3LineageEntityHandler, - "create_transformation_code_artifact", - return_value=None, - ) as create_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - generate_pipeline_version_upstream_transformation_code(), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, - "load_artifact_from_arn", - return_value=transformation_code_1, - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, - "update_transformation_code_artifact", - ) as update_transformation_code_artifact_method, patch.object( - PipelineLineageEntityHandler, - "update_pipeline_context", - ) as update_pipeline_context_method, patch.object( - 
PipelineVersionLineageEntityHandler, - "create_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ), patch.object( - LineageAssociationHandler, "add_upstream_feature_group_data_associations" - ) as add_upstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_downstream_feature_group_data_associations" - ) as add_downstream_feature_group_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_raw_data_associations" - ) as add_upstream_raw_data_associations_method, patch.object( - LineageAssociationHandler, "add_upstream_transformation_code_associations" - ) as add_upstream_transformation_code_associations_method, patch.object( - LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" - ) as add_pipeline_and_pipeline_version_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=None, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, 
"add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_lineage(TAGS) retrieve_feature_group_context_arns_method.assert_has_calls( @@ -2370,15 +2581,18 @@ def test_get_pipeline_lineage_names_when_lineage_exists(): transformation_code=TRANSFORMATION_CODE_INPUT_1, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=PIPELINE_CONTEXT, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method: + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + ): return_value = lineage_handler.get_pipeline_lineage_names() assert return_value == dict( @@ -2416,28 +2630,34 @@ def test_create_schedule_lineage(): pipeline=PIPELINE, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=PIPELINE_CONTEXT, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - S3LineageEntityHandler, - "retrieve_pipeline_schedule_artifact", - return_value=SCHEDULE_ARTIFACT_RESULT, - ) as retrieve_pipeline_schedule_artifact_method, patch.object( - LineageAssociationHandler, - "add_upstream_schedule_associations", - ) as add_upstream_schedule_associations_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + S3LineageEntityHandler, + "retrieve_pipeline_schedule_artifact", + return_value=SCHEDULE_ARTIFACT_RESULT, + ) as retrieve_pipeline_schedule_artifact_method, + patch.object( + LineageAssociationHandler, + "add_upstream_schedule_associations", + ) as add_upstream_schedule_associations_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_schedule_lineage( pipeline_name=PIPELINE_NAME, schedule_arn=SCHEDULE_ARN, @@ -2487,28 +2707,34 @@ def test_create_trigger_lineage(): pipeline=PIPELINE, sagemaker_session=SAGEMAKER_SESSION_MOCK, ) - with patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=PIPELINE_CONTEXT, - ) as 
load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - S3LineageEntityHandler, - "retrieve_pipeline_trigger_artifact", - return_value=PIPELINE_TRIGGER_ARTIFACT, - ) as retrieve_pipeline_trigger_artifact_method, patch.object( - LineageAssociationHandler, - "_add_association", - ) as add_association_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags: + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + S3LineageEntityHandler, + "retrieve_pipeline_trigger_artifact", + return_value=PIPELINE_TRIGGER_ARTIFACT, + ) as retrieve_pipeline_trigger_artifact_method, + patch.object( + LineageAssociationHandler, + "_add_association", + ) as add_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): lineage_handler.create_trigger_lineage( pipeline_name=PIPELINE_NAME, trigger_arn=TRIGGER_ARN, @@ -2564,56 +2790,68 @@ def test_upsert_tags_for_lineage_resources(): ) lineage_handler.sagemaker_session.boto_session = Mock() lineage_handler.sagemaker_session.sagemaker_client = Mock() - with patch.object( - S3LineageEntityHandler, - "retrieve_raw_data_artifact", - side_effect=[ - RAW_DATA_INPUT_ARTIFACTS[0], - RAW_DATA_INPUT_ARTIFACTS[1], - RAW_DATA_INPUT_ARTIFACTS[2], - RAW_DATA_INPUT_ARTIFACTS[3], - ], - ) as retrieve_raw_data_artifact_method, patch.object( - PipelineLineageEntityHandler, - "load_pipeline_context", - return_value=pipeline_context, - ) as load_pipeline_context_method, patch.object( - PipelineVersionLineageEntityHandler, - "load_pipeline_version_context", - return_value=PIPELINE_VERSION_CONTEXT, - ) as load_pipeline_version_context_method, patch.object( - LineageAssociationHandler, - "list_upstream_associations", - side_effect=[ - generate_pipeline_version_upstream_feature_group_list(), - generate_pipeline_version_upstream_raw_data_list(), - iter([]), - ], - ) as list_upstream_associations_method, patch.object( - LineageAssociationHandler, - "list_downstream_associations", - return_value=generate_pipeline_version_downstream_feature_group(), - ) as list_downstream_associations_method, patch.object( - S3LineageEntityHandler, "load_artifact_from_arn", return_value=ARTIFACT_RESULT - ) as load_artifact_from_arn_method, patch.object( - S3LineageEntityHandler, "_load_artifact_from_s3_uri", return_value=ARTIFACT_SUMMARY - ) as load_artifact_from_s3_uri_method, patch.object( - Artifact, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as artifact_set_tags, patch.object( - Context, - "set_tags", - return_value={ - "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] - }, - ) as context_set_tags, patch.object( - EventBridgeSchedulerHelper, "describe_schedule", return_value=dict(Arn="schedule_arn") - ) as get_event_bridge_schedule, patch.object( - EventBridgeRuleHelper, "describe_rule", 
return_value=dict(Arn="rule_arn") - ) as get_event_bridge_rule: + with ( + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + iter([]), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, "load_artifact_from_arn", return_value=ARTIFACT_RESULT + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, "_load_artifact_from_s3_uri", return_value=ARTIFACT_SUMMARY + ) as load_artifact_from_s3_uri_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + patch.object( + Context, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as context_set_tags, + patch.object( + EventBridgeSchedulerHelper, "describe_schedule", return_value=dict(Arn="schedule_arn") + ) as get_event_bridge_schedule, + patch.object( + EventBridgeRuleHelper, "describe_rule", return_value=dict(Arn="rule_arn") + ) as get_event_bridge_rule, + ): lineage_handler.upsert_tags_for_lineage_resources(TAGS) retrieve_raw_data_artifact_method.assert_has_calls( diff --git a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py index 9020a9f05f..7b35174940 100644 --- a/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py +++ b/tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py @@ -907,6 +907,10 @@ def test_remote_decorator_fields_consistency(get_execution_role, session): "use_spot_instances", "max_wait_time_in_seconds", "custom_file_filter", + "disable_output_compression", + "use_torchrun", + "use_mpirun", + "nproc_per_node", } job_settings = _JobSettings( diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 3ad641a321..0eee116e5d 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -241,7 +241,7 @@ def test_huggingface( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_version, f"pytorch{huggingface_pytorch_training_version}" diff --git a/tests/unit/sagemaker/huggingface/test_llm_utils.py 
b/tests/unit/sagemaker/huggingface/test_llm_utils.py index 3c4cdde3f6..9bb1b451a1 100644 --- a/tests/unit/sagemaker/huggingface/test_llm_utils.py +++ b/tests/unit/sagemaker/huggingface/test_llm_utils.py @@ -15,7 +15,10 @@ from unittest import TestCase from urllib.error import HTTPError from unittest.mock import Mock, patch -from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata +from sagemaker.huggingface.llm_utils import ( + get_huggingface_model_metadata, + download_huggingface_model_metadata, +) MOCK_HF_ID = "mock_hf_id" MOCK_HF_HUB_TOKEN = "mock_hf_hub_token" @@ -62,7 +65,7 @@ def test_huggingface_model_metadata_unauthorized_exception(self, mock_urllib): "Trying to access a gated/private HuggingFace model without valid credentials. " "Please provide a HUGGING_FACE_HUB_TOKEN in env_vars" ) - self.assertEquals(expected_error_msg, str(context.exception)) + self.assertEqual(expected_error_msg, str(context.exception)) @patch("sagemaker.huggingface.llm_utils.urllib") def test_huggingface_model_metadata_general_exception(self, mock_urllib): @@ -73,4 +76,26 @@ def test_huggingface_model_metadata_general_exception(self, mock_urllib): expected_error_msg = ( f"Did not find model metadata for the following HuggingFace Model ID {MOCK_HF_ID}" ) - self.assertEquals(expected_error_msg, str(context.exception)) + self.assertEqual(expected_error_msg, str(context.exception)) + + @patch("huggingface_hub.snapshot_download") + def test_download_huggingface_model_metadata(self, mock_snapshot_download): + mock_snapshot_download.side_effect = None + + download_huggingface_model_metadata(MOCK_HF_ID, "local_path", MOCK_HF_HUB_TOKEN) + + mock_snapshot_download.assert_called_once_with( + repo_id=MOCK_HF_ID, local_dir="local_path", token=MOCK_HF_HUB_TOKEN + ) + + @patch("importlib.util.find_spec") + def test_download_huggingface_model_metadata_ex(self, mock_find_spec): + mock_find_spec.side_effect = lambda *args, **kwargs: False + + self.assertRaisesRegex( + ImportError, + "Unable to import huggingface_hub, check if huggingface_hub is installed", + lambda: download_huggingface_model_metadata( + MOCK_HF_ID, "local_path", MOCK_HF_HUB_TOKEN + ), + ) diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index 40ee4978cf..d0915b8881 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -46,7 +46,13 @@ def test_jumpstart_default_hyperparameters( model_version="*", sagemaker_session=mock_session, ) - assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} + assert params == { + "train_only_top_layer": "True", + "epochs": "5", + "learning_rate": "0.001", + "batch_size": "4", + "reinitialize_top_layer": "Auto", + } patched_get_model_specs.assert_called_once_with( region=region, @@ -54,6 +60,8 @@ def test_jumpstart_default_hyperparameters( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -64,7 +72,13 @@ def test_jumpstart_default_hyperparameters( model_version="1.*", sagemaker_session=mock_session, ) - assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} + assert params == { + "train_only_top_layer": "True", + "epochs": "5", + "learning_rate": "0.001", + "batch_size": "4", + "reinitialize_top_layer": "Auto", + } 
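# A note on the expanded expectations around this hunk: assert_called_once_with
# matches the full call signature, so once get_model_specs is invoked with the
# new hub_arn and sagemaker_session pass-through arguments, every expectation
# has to list them as well. A minimal, self-contained sketch (toy mock, not the
# SDK's actual code):
from unittest.mock import Mock

specs = Mock(return_value={"epochs": "5"})
assert specs(region="us-west-2", hub_arn=None) == {"epochs": "5"}
specs.assert_called_once_with(region="us-west-2", hub_arn=None)  # passes
# specs.assert_called_once_with(region="us-west-2") would raise AssertionError:
# every positional and keyword argument must match exactly, including new ones.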
patched_get_model_specs.assert_called_once_with( region=region, @@ -72,6 +86,8 @@ def test_jumpstart_default_hyperparameters( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -84,12 +100,14 @@ def test_jumpstart_default_hyperparameters( sagemaker_session=mock_session, ) assert params == { - "adam-learning-rate": "0.05", - "batch-size": "4", - "epochs": "3", - "sagemaker_container_log_level": "20", - "sagemaker_program": "transfer_learning.py", + "train_only_top_layer": "True", + "epochs": "5", + "learning_rate": "0.001", + "batch_size": "4", + "reinitialize_top_layer": "Auto", "sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz", + "sagemaker_program": "transfer_learning.py", + "sagemaker_container_log_level": "20", } patched_get_model_specs.assert_called_once_with( @@ -98,6 +116,8 @@ def test_jumpstart_default_hyperparameters( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index fdc29b4d90..af5413ce6b 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -21,7 +21,7 @@ from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError from sagemaker.jumpstart.types import JumpStartHyperparameter -from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec region = "us-west-2" mock_client = boto3.client("s3") @@ -34,7 +34,7 @@ def test_jumpstart_validate_provided_hyperparameters( patched_get_model_specs, patched_validate_model_id_and_get_type ): def add_options_to_hyperparameter(*largs, **kwargs): - spec = get_spec_from_base_spec(*largs, **kwargs) + spec = get_prototype_model_spec(*largs, **kwargs) spec.hyperparameters.extend( [ JumpStartHyperparameter( @@ -115,7 +115,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): patched_get_model_specs.side_effect = add_options_to_hyperparameter patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + model_id, model_version = "mxnet-semseg-fcn-resnet50-ade", "*" region = "us-west-2" hyperparameter_to_test = { @@ -146,6 +146,8 @@ def add_options_to_hyperparameter(*largs, **kwargs): version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -410,7 +412,7 @@ def test_jumpstart_validate_algorithm_hyperparameters( patched_get_model_specs, patched_validate_model_id_and_get_type ): def add_options_to_hyperparameter(*largs, **kwargs): - spec = get_spec_from_base_spec(*largs, **kwargs) + spec = get_prototype_model_spec(*largs, **kwargs) spec.hyperparameters.append( JumpStartHyperparameter( { @@ -427,10 +429,11 @@ def add_options_to_hyperparameter(*largs, **kwargs): patched_get_model_specs.side_effect = add_options_to_hyperparameter patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + model_id, model_version = "mxnet-semseg-fcn-resnet50-ade", "*" region 
= "us-west-2" hyperparameter_to_test = { + "train-only-top-layer": "True", "adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3", @@ -452,6 +455,8 @@ def add_options_to_hyperparameter(*largs, **kwargs): version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -484,13 +489,14 @@ def test_jumpstart_validate_all_hyperparameters( patched_get_model_specs, patched_validate_model_id_and_get_type ): - patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_get_model_specs.side_effect = get_prototype_model_spec patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + model_id, model_version = "mxnet-semseg-fcn-resnet50-ade", "*" region = "us-west-2" hyperparameter_to_test = { + "train-only-top-layer": "True", "adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3", @@ -514,6 +520,8 @@ def test_jumpstart_validate_all_hyperparameters( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/expected_uris.py b/tests/unit/sagemaker/image_uris/expected_uris.py index 094323ef0b..eb198454fc 100644 --- a/tests/unit/sagemaker/image_uris/expected_uris.py +++ b/tests/unit/sagemaker/image_uris/expected_uris.py @@ -12,12 +12,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -ALTERNATE_DOMAINS = { - "cn-north-1": "amazonaws.com.cn", - "cn-northwest-1": "amazonaws.com.cn", - "us-iso-east-1": "c2s.ic.gov", - "us-isob-east-1": "sc2s.sgov.gov", -} +from sagemaker.utils import ALTERNATE_DOMAINS + DOMAIN = "amazonaws.com" IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:{}" MONITOR_URI_FORMAT = "{}.dkr.ecr.{}.{}/sagemaker-model-monitor-analyzer" @@ -111,3 +107,12 @@ def base_python_uri(repo, account, region=REGION): domain = ALTERNATE_DOMAINS.get(region, DOMAIN) tag = "1.0" return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) + + +def sagemaker_distribution_uri(repo, account, tag, processor, region=REGION): + domain = ALTERNATE_DOMAINS.get(region, DOMAIN) + if processor == "cpu": + tag = f"{tag}-cpu" + else: + tag = f"{tag}-gpu" + return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py b/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py index 9261fd561e..98ddefb2c2 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py @@ -30,7 +30,7 @@ def test_jumpstart_catboost_image_uri(patched_get_model_specs, session): patched_get_model_specs.side_effect = get_prototype_model_spec model_id, model_version = "catboost-classification-model", "*" - instance_type = "ml.p2.xlarge" + instance_type = "ml.m5.xlarge" region = "us-west-2" model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) @@ -55,7 +55,7 @@ def test_jumpstart_catboost_image_uri(patched_get_model_specs, session): ).serving_image_uri(region, instance_type) assert uri == framework_class_uri - assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38" + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310" # 
training uri = image_uris.retrieve( @@ -78,4 +78,4 @@ def test_jumpstart_catboost_image_uri(patched_get_model_specs, session): ).training_image_uri(region=region) assert uri == framework_class_uri - assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38" + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38" diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 88b95b9403..9a4febd0e4 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -47,15 +47,17 @@ def test_jumpstart_common_image_uri( image_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -68,15 +70,17 @@ def test_jumpstart_common_image_uri( image_scope="inference", model_id="pytorch-ic-mobilenet-v2", model_version="1.*", - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -89,15 +93,17 @@ def test_jumpstart_common_image_uri( image_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -110,15 +116,17 @@ def test_jumpstart_common_image_uri( image_scope="inference", model_id="pytorch-ic-mobilenet-v2", model_version="1.*", - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -129,7 +137,7 @@ def test_jumpstart_common_image_uri( image_scope="BAD_SCOPE", model_id="pytorch-ic-mobilenet-v2", model_version="*", - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", ) with pytest.raises(KeyError): @@ -139,7 +147,7 @@ def test_jumpstart_common_image_uri( image_scope="training", model_id="blah", model_version="*", - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", ) with pytest.raises(ValueError): @@ -149,7 +157,7 @@ def test_jumpstart_common_image_uri( image_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", - instance_type="ml.p2.xlarge", + 
instance_type="ml.m5.xlarge", ) with pytest.raises(ValueError): @@ -158,7 +166,7 @@ def test_jumpstart_common_image_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", model_version="*", - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", ) with pytest.raises(ValueError): @@ -167,7 +175,7 @@ def test_jumpstart_common_image_uri( region="us-west-2", image_scope="training", model_version="*", - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", ) with pytest.raises(ValueError): @@ -176,5 +184,56 @@ def test_jumpstart_common_image_uri( framework=None, image_scope="training", model_id="pytorch-ic-mobilenet-v2", - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", ) + + +@patch("sagemaker.image_uris.JUMPSTART_LOGGER.info") +@patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") +@patch("sagemaker.jumpstart.artifacts.image_uris.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_image_uri_logging_extra_fields( + patched_get_model_specs, + patched_verify_model_region_and_return_specs, + patched_validate_model_id_and_get_type, + patched_info_log, +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_spec_from_base_spec + patched_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + + region = "us-west-2" + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client, boto_region_name=region) + + image_uris.retrieve( + framework=None, + region="us-west-2", + image_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + patched_info_log.assert_not_called() + + image_uris.retrieve( + framework="framework", + container_version="1.2.3", + region="us-west-2", + image_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + patched_info_log.assert_called_once_with( + "Ignoring the following arguments " + "when retrieving image uri for " + "JumpStart model id '%s': %s", + "pytorch-ic-mobilenet-v2", + "{'framework': 'framework', 'container_version': '1.2.3'}", + ) diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py b/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py index 1ce213cd27..0f426103e8 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py @@ -28,7 +28,8 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs, session): patched_get_model_specs.side_effect = get_prototype_model_spec model_id, model_version = "huggingface-spc-bert-base-cased", "*" - instance_type = "ml.p2.xlarge" + instance_type = "ml.m5.xlarge" + training_instance_type = "ml.p3.2xlarge" region = "us-west-2" model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) @@ -55,7 +56,7 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs, session): assert ( uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:" - "1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04" + "1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04" ) # training @@ -65,7 +66,7 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs, session): image_scope="training", model_id=model_id, 
model_version=model_version, - instance_type=instance_type, + instance_type=training_instance_type, ) framework_class_uri = HuggingFace( @@ -75,7 +76,7 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs, session): entry_point="some_entry_point", transformers_version=model_specs.training_ecr_specs.huggingface_transformers_version, pytorch_version=model_specs.training_ecr_specs.framework_version, - instance_type=instance_type, + instance_type=training_instance_type, instance_count=1, sagemaker_session=session, ).training_image_uri(region=region) diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py b/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py index e907a19b51..159d801867 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py @@ -28,7 +28,7 @@ def test_jumpstart_lightgbm_image_uri(patched_get_model_specs, session): patched_get_model_specs.side_effect = get_prototype_model_spec model_id, model_version = "lightgbm-classification-model", "*" - instance_type = "ml.p2.xlarge" + instance_type = "ml.m5.xlarge" region = "us-west-2" model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) @@ -53,7 +53,7 @@ def test_jumpstart_lightgbm_image_uri(patched_get_model_specs, session): ).serving_image_uri(region, instance_type) assert uri == framework_class_uri - assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38" + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310" # training uri = image_uris.retrieve( @@ -76,4 +76,4 @@ def test_jumpstart_lightgbm_image_uri(patched_get_model_specs, session): ).training_image_uri(region=region) assert uri == framework_class_uri - assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38" + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38" diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py b/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py index 9fd09d47d9..4d09c36b68 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py @@ -28,7 +28,7 @@ def test_jumpstart_mxnet_image_uri(patched_get_model_specs, session): patched_get_model_specs.side_effect = get_prototype_model_spec model_id, model_version = "mxnet-semseg-fcn-resnet50-ade", "*" - instance_type = "ml.p2.xlarge" + instance_type = "ml.m5.xlarge" region = "us-west-2" model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) @@ -53,7 +53,7 @@ def test_jumpstart_mxnet_image_uri(patched_get_model_specs, session): ).serving_image_uri(region, instance_type) assert uri == framework_class_uri - assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.7.0-gpu-py3" + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.9.0-cpu-py38" # training uri = image_uris.retrieve( @@ -76,4 +76,4 @@ def test_jumpstart_mxnet_image_uri(patched_get_model_specs, session): ).training_image_uri(region=region) assert uri == framework_class_uri - assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:1.7.0-gpu-py3" + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:1.9.0-cpu-py38" diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py 
b/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py index a94801da10..24e759e6bc 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py @@ -27,7 +27,7 @@ def test_jumpstart_pytorch_image_uri(patched_get_model_specs, session): patched_get_model_specs.side_effect = get_prototype_model_spec - model_id, model_version = "pytorch-eqa-bert-base-cased", "*" + model_id, model_version = "pytorch-ic-mobilenet-v2", "*" instance_type = "ml.p2.xlarge" region = "us-west-2" @@ -53,7 +53,7 @@ def test_jumpstart_pytorch_image_uri(patched_get_model_specs, session): ).serving_image_uri(region, instance_type) assert uri == framework_class_uri - assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3" + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38" # training uri = image_uris.retrieve( @@ -76,4 +76,4 @@ def test_jumpstart_pytorch_image_uri(patched_get_model_specs, session): ).training_image_uri(region=region) assert uri == framework_class_uri - assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3" + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38" diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py b/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py index 1410c59bb6..af0614c465 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py @@ -29,7 +29,7 @@ def test_jumpstart_sklearn_image_uri(patched_get_model_specs, session): patched_get_model_specs.side_effect = get_prototype_model_spec model_id, model_version = "sklearn-classification-linear", "*" - instance_type = "ml.m2.xlarge" + instance_type = "ml.m5.xlarge" region = "us-west-2" model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version) @@ -53,9 +53,15 @@ def test_jumpstart_sklearn_image_uri(patched_get_model_specs, session): sagemaker_session=session, ).serving_image_uri(region, instance_type) - assert uri == framework_class_uri + # framework classes don't use digest. assert ( - uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3" + framework_class_uri + == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:1.2-1" + "-cpu-py3" + ) + assert ( + uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn@" + "sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" ) # training @@ -79,9 +85,14 @@ def test_jumpstart_sklearn_image_uri(patched_get_model_specs, session): sagemaker_session=session, ).training_image_uri(region=region) - assert uri == framework_class_uri + # framework classes don't use digest.
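# The reworked assertions in these image-uri tests pin the JumpStart-resolved
# URI to an immutable @sha256 digest, while the framework classes still resolve
# a mutable tag. A digest reference is content-addressed, so every pull is
# byte-identical; a tag can be repointed at a new image. A small sketch of the
# distinction (helper name is illustrative only):
def image_ref_kind(ref: str) -> str:
    # "repo@sha256:<hex>" pins an immutable digest; "repo:tag" stays mutable.
    return "digest" if "@sha256:" in ref else "tag"

assert image_ref_kind("246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.7-1") == "tag"
assert image_ref_kind(
    "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost"
    "@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444"
) == "digest"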
+ assert ( + framework_class_uri + == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:1.2-1-cpu-py3" + ) assert ( - uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3" + uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn" + "@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" ) with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py b/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py index c924615212..cfe57f4053 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py @@ -52,7 +52,7 @@ def test_jumpstart_tensorflow_image_uri(patched_get_model_specs, session): ).serving_image_uri(region, instance_type) assert uri == framework_class_uri - assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:2.3-gpu" + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:2.8-gpu" # training uri = image_uris.retrieve( @@ -75,4 +75,4 @@ def test_jumpstart_tensorflow_image_uri(patched_get_model_specs, session): ).training_image_uri(region=region) assert uri == framework_class_uri - assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.3-gpu-py37" + assert uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.9-gpu-py39" diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_variants.py b/tests/unit/sagemaker/image_uris/jumpstart/test_variants.py index 20547caca3..80bf54f722 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_variants.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_variants.py @@ -52,9 +52,8 @@ def test_jumpstart_variants_image_uri( instance_type="ml.c2.xlarge", ) - assert ( - "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-cpu-py3" - == image_uris.retrieve( + with pytest.raises(ValueError): + image_uris.retrieve( framework=None, region="us-west-2", image_scope="inference", @@ -62,7 +61,6 @@ model_version="*", instance_type="ml.c200000.xlarge", ) - ) with pytest.raises(ValueError): image_uris.retrieve( @@ -74,9 +72,8 @@ instance_type="ml.c2.xlarge", ) - assert ( - "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3" - == image_uris.retrieve( + with pytest.raises(ValueError): + image_uris.retrieve( framework=None, region="us-west-2", image_scope="training", @@ -84,4 +81,3 @@ model_version="*", instance_type="ml.g4dn.2xlarge", ) - ) diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py b/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py index 5da3b71176..7817980bac 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py @@ -52,8 +52,15 @@ def test_jumpstart_xgboost_image_uri(patched_get_model_specs, session): sagemaker_session=session, ).serving_image_uri(region, instance_type) - assert uri == framework_class_uri - assert uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.3-1" + # framework classes don't use digest + assert ( + framework_class_uri + == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.7-1" + ) + assert ( + uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost@sha256:"
"ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + ) # training uri = image_uris.retrieve( @@ -76,5 +83,12 @@ def test_jumpstart_xgboost_image_uri(patched_get_model_specs, session): sagemaker_session=session, ).training_image_uri(region=region) - assert uri == framework_class_uri - assert uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.3-1" + # framework classes dont use digest + assert ( + framework_class_uri + == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:1.7-1" + ) + assert ( + uri == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost" + "@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + ) diff --git a/tests/unit/sagemaker/image_uris/test_djl.py b/tests/unit/sagemaker/image_uris/test_djl.py index 6457fe044f..887b575fdf 100644 --- a/tests/unit/sagemaker/image_uris/test_djl.py +++ b/tests/unit/sagemaker/image_uris/test_djl.py @@ -18,12 +18,7 @@ @pytest.mark.parametrize( "load_config_and_file_name", - [ - "djl-neuronx.json", - "djl-fastertransformer.json", - "djl-deepspeed.json", - "djl-tensorrtllm.json", - ], + ["djl-neuronx.json", "djl-tensorrtllm.json", "djl-lmi.json"], indirect=True, ) def test_djl_uris(load_config_and_file_name): diff --git a/tests/unit/sagemaker/image_uris/test_graviton.py b/tests/unit/sagemaker/image_uris/test_graviton.py index ea4ef29919..a122be9291 100644 --- a/tests/unit/sagemaker/image_uris/test_graviton.py +++ b/tests/unit/sagemaker/image_uris/test_graviton.py @@ -30,11 +30,18 @@ ] -def _test_graviton_framework_uris(framework, version, py_version, account, region): +def _test_graviton_framework_uris( + framework, version, py_version, account, region, container_version="ubuntu20.04-sagemaker" +): for instance_type in GRAVITON_INSTANCE_TYPES: uri = image_uris.retrieve(framework, region, instance_type=instance_type, version=version) expected = _expected_graviton_framework_uri( - framework, version, py_version, account, region=region + framework, + version, + py_version, + account, + region=region, + container_version=container_version, ) assert expected == uri @@ -50,11 +57,21 @@ def test_graviton_framework_uris(load_config_and_file_name, scope): for version in VERSIONS: ACCOUNTS = config[scope]["versions"][version]["registries"] py_versions = config[scope]["versions"][version]["py_versions"] + container_version = ( + config[scope]["versions"][version].get("container_version", {}).get("cpu", None) + ) + if container_version: + container_version = container_version + "-sagemaker" for py_version in py_versions: for region in ACCOUNTS.keys(): - _test_graviton_framework_uris( - framework, version, py_version, ACCOUNTS[region], region - ) + if container_version: + _test_graviton_framework_uris( + framework, version, py_version, ACCOUNTS[region], region, container_version + ) + else: + _test_graviton_framework_uris( + framework, version, py_version, ACCOUNTS[region], region + ) def _test_graviton_unsupported_framework(framework, region, framework_version): @@ -183,11 +200,14 @@ def test_graviton_sklearn_image_scope_specified_x86_instance(graviton_sklearn_un assert "Unsupported instance type: m5." 
in str(error) -def _expected_graviton_framework_uri(framework, version, py_version, account, region): +def _expected_graviton_framework_uri( + framework, version, py_version, account, region, container_version +): return expected_uris.graviton_framework_uri( "{}-inference-graviton".format(framework), fw_version=version, py_version=py_version, account=account, region=region, + container_version=container_version, ) diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index 582e5cf82d..e693b9f8ce 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -13,11 +13,26 @@ from __future__ import absolute_import import pytest +from packaging.version import parse from sagemaker.huggingface import get_huggingface_llm_image_uri from tests.unit.sagemaker.image_uris import expected_uris, conftest LMI_VERSIONS = ["0.24.0"] +TEI_VERSIONS_MAPPING = { + "gpu": { + "1.2.3": "2.0.1-tei1.2.3-gpu-py310-cu122-ubuntu22.04", + "1.4.0": "2.0.1-tei1.4.0-gpu-py310-cu122-ubuntu22.04", + "1.6.0": "2.0.1-tei1.6.0-gpu-py310-cu122-ubuntu22.04", + "1.7.0": "2.0.1-tei1.7.0-gpu-py310-cu122-ubuntu22.04", + }, + "cpu": { + "1.2.3": "2.0.1-tei1.2.3-cpu-py310-ubuntu22.04", + "1.4.0": "2.0.1-tei1.4.0-cpu-py310-ubuntu22.04", + "1.6.0": "2.0.1-tei1.6.0-cpu-py310-ubuntu22.04", + "1.7.0": "2.0.1-tei1.7.0-cpu-py310-ubuntu22.04", + }, +} HF_VERSIONS_MAPPING = { "gpu": { "0.6.0": "2.0.0-tgi0.6.0-gpu-py39-cu118-ubuntu20.04", @@ -32,6 +47,12 @@ "1.4.2": "2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04", "1.4.5": "2.1.1-tgi1.4.5-gpu-py310-cu121-ubuntu22.04", "2.0.0": "2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04", + "2.0.1": "2.1.1-tgi2.0.1-gpu-py310-cu121-ubuntu22.04", + "2.0.2": "2.3.0-tgi2.0.2-gpu-py310-cu121-ubuntu22.04", + "2.2.0": "2.3.0-tgi2.2.0-gpu-py310-cu121-ubuntu22.04-v2.0", + "2.3.1": "2.4.0-tgi2.3.1-gpu-py311-cu124-ubuntu22.04", + "2.4.0": "2.4.0-tgi2.4.0-gpu-py311-cu124-ubuntu22.04-v2.2", + "3.0.1": "2.4.0-tgi3.0.1-gpu-py311-cu124-ubuntu22.04-v2.1", }, "inf2": { "0.0.16": "1.13.1-optimum0.0.16-neuronx-py310-ubuntu22.04", @@ -40,6 +61,11 @@ "0.0.19": "1.13.1-optimum0.0.19-neuronx-py310-ubuntu22.04", "0.0.20": "1.13.1-optimum0.0.20-neuronx-py310-ubuntu22.04", "0.0.21": "1.13.1-optimum0.0.21-neuronx-py310-ubuntu22.04", + "0.0.22": "2.1.2-optimum0.0.22-neuronx-py310-ubuntu22.04", + "0.0.23": "2.1.2-optimum0.0.23-neuronx-py310-ubuntu22.04", + "0.0.24": "2.1.2-optimum0.0.24-neuronx-py310-ubuntu22.04", + "0.0.25": "2.1.2-optimum0.0.25-neuronx-py310-ubuntu22.04", + "0.0.27": "2.1.2-optimum0.0.27-neuronx-py310-ubuntu22.04", }, } @@ -51,10 +77,31 @@ def test_huggingface_uris(load_config): VERSIONS = load_config["inference"]["versions"] device = load_config["inference"]["processors"][0] backend = "huggingface-neuronx" if device == "inf2" else "huggingface" + + # Fail if device is not in mapping + if device not in HF_VERSIONS_MAPPING: + raise ValueError(f"Device {device} not found in HF_VERSIONS_MAPPING") + + # Get highest version for the device + highest_version = max(HF_VERSIONS_MAPPING[device].keys(), key=lambda x: parse(x)) + for version in VERSIONS: ACCOUNTS = load_config["inference"]["versions"][version]["registries"] for region in ACCOUNTS.keys(): uri = get_huggingface_llm_image_uri(backend, region=region, version=version) + + # Skip only if test version is higher than highest known version. 
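A note on the ordering used just above: `highest_version` is taken with `max(..., key=lambda x: parse(x))` because `packaging.version.parse` compares release segments numerically, whereas plain string ordering is lexicographic and mis-ranks multi-digit components. A minimal standalone sketch (assuming only that `packaging` is installed):

from packaging.version import parse

# Lexicographic max: '9' > '2' at the first differing character.
assert max(["0.0.9", "0.0.21"]) == "0.0.9"

# parse() compares release segments as integers: 21 > 9.
assert max(["0.0.9", "0.0.21"], key=parse) == "0.0.21"
assert parse("2.0.1") > parse("2.0.0") > parse("0.6.0")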
+ # There's now automation to add new TGI releases to image_uri_config directory + # that doesn't involve a human raising a PR. + if parse(version) > parse(highest_version): + print( + f"Skipping version check for {version} as there is " + "automation that now updates the image_uri_config " + "without a human raising a PR. Tests will pass for " + f"versions higher than {highest_version} that are not in HF_VERSIONS_MAPPING." + ) + continue + expected = expected_uris.huggingface_llm_framework_uri( "huggingface-pytorch-tgi-inference", ACCOUNTS[region], @@ -65,6 +112,28 @@ def test_huggingface_uris(load_config): assert expected == uri +@pytest.mark.parametrize( + "load_config", ["huggingface-tei.json", "huggingface-tei-cpu.json"], indirect=True +) +def test_huggingface_tei_uris(load_config): + VERSIONS = load_config["inference"]["versions"] + device = load_config["inference"]["processors"][0] + backend = "huggingface-tei" if device == "gpu" else "huggingface-tei-cpu" + repo = "tei" if device == "gpu" else "tei-cpu" + for version in VERSIONS: + ACCOUNTS = load_config["inference"]["versions"][version]["registries"] + for region in ACCOUNTS.keys(): + uri = get_huggingface_llm_image_uri(backend, region=region, version=version) + expected = expected_uris.huggingface_llm_framework_uri( + repo, + ACCOUNTS[region], + version, + TEI_VERSIONS_MAPPING[device][version], + region=region, + ) + assert expected == uri + + @pytest.mark.parametrize("load_config", ["huggingface-llm.json"], indirect=True) def test_lmi_uris(load_config): VERSIONS = load_config["inference"]["versions"] diff --git a/tests/unit/sagemaker/image_uris/test_retrieve.py b/tests/unit/sagemaker/image_uris/test_retrieve.py index fd0bcbd150..360587677f 100644 --- a/tests/unit/sagemaker/image_uris/test_retrieve.py +++ b/tests/unit/sagemaker/image_uris/test_retrieve.py @@ -221,7 +221,6 @@ def test_retrieve_default_version_if_possible(config_for_framework, caplog): image_scope="training", ) assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri - assert "Ignoring framework/algorithm version: invalid-version." in caplog.text @patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG) @@ -239,18 +238,6 @@ def test_retrieve_unsupported_version(config_for_framework): assert "Unsupported some-framework version: 1." in str(e.value) assert "Supported some-framework version(s): 1.0.0, 1.1.0." in str(e.value) - with pytest.raises(ValueError) as e: - image_uris.retrieve( - framework="some-framework", - py_version="py3", - instance_type="ml.c4.xlarge", - region="us-west-2", - image_scope="training", - ) - - assert "Unsupported some-framework version: None." in str(e.value) - assert "Supported some-framework version(s): 1.0.0, 1.1.0." 
in str(e.value)
-

 @patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG)
 def test_retrieve_unsupported_region(config_for_framework):
@@ -780,3 +767,105 @@ def test_retrieve_with_pipeline_variable():
         ),
         image_scope="training",
     )
+
+
+@patch("sagemaker.image_uris.config_for_framework")
+def test_get_latest_version_function_with_invalid_framework(config_for_framework):
+    config_for_framework.side_effect = FileNotFoundError
+
+    with pytest.raises(Exception) as e:
+        image_uris.retrieve("xgboost", "inference")
+    assert "No framework config for framework" in str(e.value)
+
+
+@patch("sagemaker.image_uris.config_for_framework")
+def test_get_latest_version_function_with_no_framework(config_for_framework):
+    config_for_framework.side_effect = {}
+
+    with pytest.raises(Exception) as e:
+        image_uris.retrieve("xgboost", "inference")
+    assert "No framework config for framework" in str(e.value)
+
+
+@pytest.mark.parametrize(
+    "framework",
+    [
+        "object-detection",
+        "instance_gpu_info",
+        "object2vec",
+        "pytorch",
+        "djl-lmi",
+        "mxnet",
+        "debugger",
+        "data-wrangler",
+        "spark",
+        "blazingtext",
+        "pytorch-neuron",
+        "forecasting-deepar",
+        "huggingface-neuron",
+        "ntm",
+        "neo-mxnet",
+        "image-classification",
+        "xgboost",
+        "autogluon",
+        "sparkml-serving",
+        "clarify",
+        "inferentia-pytorch",
+        "neo-tensorflow",
+        "huggingface-tei-cpu",
+        "huggingface",
+        "sagemaker-tritonserver",
+        "pytorch-smp",
+        "knn",
+        "linear-learner",
+        "model-monitor",
+        "ray-tensorflow",
+        "djl-neuronx",
+        "huggingface-llm-neuronx",
+        "image-classification-neo",
+        "lda",
+        "stabilityai",
+        "ray-pytorch",
+        "chainer",
+        "coach-mxnet",
+        "pca",
+        "sagemaker-geospatial",
+        "djl-tensorrtllm",
+        "huggingface-training-compiler",
+        "pytorch-training-compiler",
+        "vw",
+        "huggingface-neuronx",
+        "ipinsights",
+        "detailed-profiler",
+        "inferentia-tensorflow",
+        "semantic-segmentation",
+        "inferentia-mxnet",
+        "xgboost-neo",
+        "neo-pytorch",
+        "djl-deepspeed",
+        "djl-fastertransformer",
+        "sklearn",
+        "tensorflow",
+        "randomcutforest",
+        "huggingface-llm",
+        "factorization-machines",
+        "huggingface-tei",
+        "coach-tensorflow",
+        "seq2seq",
+        "kmeans",
+        "sagemaker-base-python",
+    ],
+)
+@patch("sagemaker.image_uris.config_for_framework")
+@patch("sagemaker.image_uris.retrieve")
+def test_retrieve_with_parameterized(mock_image_retrieve, mock_config_for_framework, framework):
+    try:
+        image_uris.retrieve(
+            framework=framework,
+            region="us-east-1",
+            version=None,
+            instance_type="ml.c4.xlarge",
+            image_scope="inference",
+        )
+    except ValueError as e:
+        pytest.fail(str(e))
diff --git a/tests/unit/sagemaker/image_uris/test_sagemaker_distribution.py b/tests/unit/sagemaker/image_uris/test_sagemaker_distribution.py
new file mode 100644
index 0000000000..adc51064f1
--- /dev/null
+++ b/tests/unit/sagemaker/image_uris/test_sagemaker_distribution.py
@@ -0,0 +1,47 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
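For reference, the `test_get_latest_version_function_*` additions above read the raised exception through pytest's `ExceptionInfo` object, which exposes the exception instance as `.value` (there is no `.exception` attribute on it; that spelling belongs to unittest). A minimal self-contained sketch of the pattern; `load_config` here is a hypothetical stand-in for `config_for_framework`, not SDK code:

import pytest


def load_config(framework):
    # Hypothetical stand-in: an unknown framework has no config file on disk.
    raise FileNotFoundError(f"No framework config for framework {framework}")


def test_missing_config_message():
    with pytest.raises(Exception) as excinfo:
        load_config("not-a-framework")
    # ExceptionInfo exposes the raised exception as .value; the assertion
    # runs after the with-block, once .value has been populated.
    assert "No framework config for framework" in str(excinfo.value)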
+from __future__ import absolute_import + +import pytest +from sagemaker import image_uris +from tests.unit.sagemaker.image_uris import expected_uris + +INSTANCE_TYPES = {"cpu": "ml.c4.xlarge", "gpu": "ml.p2.xlarge"} + + +def _test_ecr_uri(account, region, version, tag, instance_type, processor): + actual_uri = image_uris.retrieve( + "sagemaker-distribution", region=region, instance_type=instance_type, version=version + ) + expected_uri = expected_uris.sagemaker_distribution_uri( + "sagemaker-distribution-prod", account, tag, processor, region + ) + return expected_uri == actual_uri + + +@pytest.mark.parametrize("load_config", ["sagemaker-distribution.json"], indirect=True) +def test_sagemaker_distribution_ecr_uri(load_config): + VERSIONS = load_config["versions"] + processors = load_config["processors"] + for version in VERSIONS: + SAGEMAKER_DISTRIBUTION_ACCOUNTS = load_config["versions"][version]["registries"] + for region in SAGEMAKER_DISTRIBUTION_ACCOUNTS.keys(): + for processor in processors: + assert _test_ecr_uri( + account=SAGEMAKER_DISTRIBUTION_ACCOUNTS[region], + region=region, + version=version, + tag="3.2.0", + instance_type=INSTANCE_TYPES[processor], + processor=processor, + ) diff --git a/tests/unit/sagemaker/image_uris/test_smp_v2.py b/tests/unit/sagemaker/image_uris/test_smp_v2.py index b53a45133e..3177384e7e 100644 --- a/tests/unit/sagemaker/image_uris/test_smp_v2.py +++ b/tests/unit/sagemaker/image_uris/test_smp_v2.py @@ -27,6 +27,7 @@ def test_smp_v2(load_config): "torch_distributed": {"enabled": True}, "smdistributed": {"modelparallel": {"enabled": True}}, } + for processor in PROCESSORS: for version in VERSIONS: ACCOUNTS = load_config["training"]["versions"][version]["registries"] @@ -35,9 +36,20 @@ def test_smp_v2(load_config): for region in ACCOUNTS.keys(): for instance_type in CONTAINER_VERSIONS.keys(): cuda_vers = CONTAINER_VERSIONS[instance_type] - if "2.1" in version or "2.2" in version: + supported_smp_pt_versions_cu124 = ("2.5",) + supported_smp_pt_versions_cu121 = ("2.1", "2.2", "2.3", "2.4") + if any( + pt_version in version for pt_version in supported_smp_pt_versions_cu124 + ): + cuda_vers = "cu124" + elif any( + pt_version in version for pt_version in supported_smp_pt_versions_cu121 + ): cuda_vers = "cu121" + if version in ("2.3.1", "2.4.1", "2.5.1"): + py_version = "py311" + uri = image_uris.get_training_image_uri( region, framework="pytorch", diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 2e51afd3f7..2b73766ea4 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -43,7 +43,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode scope="training", sagemaker_session=mock_session, ) - assert default_training_instance_types == "ml.p3.2xlarge" + assert default_training_instance_types == "ml.m5.xlarge" patched_get_model_specs.assert_called_once_with( region=region, @@ -51,6 +51,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -62,7 +64,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode scope="inference", sagemaker_session=mock_session, ) - assert 
default_inference_instance_types == "ml.p2.xlarge" + assert default_inference_instance_types == "ml.m5.large" patched_get_model_specs.assert_called_once_with( region=region, @@ -70,6 +72,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -81,13 +85,7 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode scope="training", sagemaker_session=mock_session, ) - assert default_training_instance_types == [ - "ml.p3.2xlarge", - "ml.p2.xlarge", - "ml.g4dn.2xlarge", - "ml.m5.xlarge", - "ml.c5.2xlarge", - ] + assert default_training_instance_types == ["ml.m5.xlarge", "ml.c5.2xlarge", "ml.m4.xlarge"] patched_get_model_specs.assert_called_once_with( region=region, @@ -95,6 +93,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -107,13 +107,12 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode sagemaker_session=mock_session, ) assert default_inference_instance_types == [ - "ml.p2.xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", "ml.m5.large", "ml.m5.xlarge", "ml.c5.xlarge", "ml.c5.2xlarge", + "ml.m4.large", + "ml.m4.xlarge", ] patched_get_model_specs.assert_called_once_with( @@ -122,6 +121,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index f165a513a9..ae02c597da 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -12,8 +12,281 @@ # language governing permissions and limitations under the License. 
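The `assert_called_once_with` checks above pin down every keyword the accessor forwards to `get_model_specs`, including the newly threaded `hub_arn=None` and `sagemaker_session`. A minimal sketch of that mock pattern; `fetch_default_instance_type` is illustrative only, not SDK API:

from unittest.mock import MagicMock


def fetch_default_instance_type(model_id, session, specs_getter):
    # Illustrative accessor: it must forward the session and an explicit
    # hub_arn=None to the specs lookup, mirroring the behavior under test.
    specs = specs_getter(model_id=model_id, hub_arn=None, sagemaker_session=session)
    return specs.default_training_instance_type


def test_forwards_session_and_hub_arn():
    specs = MagicMock(default_training_instance_type="ml.m5.xlarge")
    getter = MagicMock(return_value=specs)
    session = MagicMock()

    assert fetch_default_instance_type("some-model", session, getter) == "ml.m5.xlarge"
    # Fails if any expected keyword is missing, renamed, or different.
    getter.assert_called_once_with(
        model_id="some-model", hub_arn=None, sagemaker_session=session
    )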
from __future__ import absolute_import +# flake8: noqa: E501 SPECIAL_MODEL_SPECS_DICT = { + "js-model-class-model-prepacked": { + "model_id": "huggingface-txt2img-conflictx-complex-lineart", + "url": "https://huggingface.co/Conflictx/Complex-Lineart", + "version": "2.0.3", + "min_sdk_version": "2.189.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface", + "framework_version": "1.10.2", + "py_version": "py38", + "huggingface_transformers_version": "4.17.0", + }, + "hosting_artifact_key": "huggingface-txt2img/huggingface-txt2img-conflictx-complex-lineart/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/txt2img/v1.1.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-txt2img/huggingface-txt2img-conflictx-complex-lineart/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [ + "accelerate==0.16.0", + "diffusers==0.12.1", + "huggingface_hub==0.12.0", + "transformers==4.26.0", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.p3.2xlarge", + "supported_inference_instance_types": [ + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.g4dn.2xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json", "application/x-text"], + "supported_accept_types": [ + "application/json", + "application/json;verbose", + "application/json;jpeg", + ], + "default_content_type": "application/x-text", + "default_accept_type": "application/json;jpeg", + }, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "hf-txt2img-conflictx-complex-lineart", + "default_payloads": { + "Astronaut": {"content_type": "application/x-text", "body": "astronaut on a horse"} + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": 
"626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + 
"cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "eu-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "sa-east-1": 
{ + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "us-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "us-east-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-cpu-py38-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": 
{"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + }, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, + }, "gemma-model": { "model_id": "huggingface-llm-gemma-7b-instruct", "url": "https://huggingface.co/google/gemma-7b-it", @@ -1250,923 +1523,966 @@ "dynamic_container_deployment_supported": True, }, }, - "env-var-variant-model": { - "model_id": "huggingface-llm-falcon-180b-bf16", - "url": "https://huggingface.co/tiiuae/falcon-180B", - "version": "1.0.0", - "min_sdk_version": "2.175.0", - "training_supported": False, + # noqa: E501 + "gemma-model-2b-v1_1_0": { + "model_id": "huggingface-llm-gemma-2b-instruct", + "url": "https://huggingface.co/google/gemma-2b-it", + "version": "1.1.0", + "min_sdk_version": "2.189.0", + "training_supported": True, "incremental_training_supported": False, "hosting_ecr_specs": { "framework": "huggingface-llm", - "framework_version": "0.9.3", - "py_version": "py39", - "huggingface_transformers_version": "4.29.2", + "framework_version": "1.4.2", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_artifact_key": "huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference/v1.0.0/", "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" - "-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.1", + "hosting_prepacked_artifact_key": ( + "huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference-prepack/v1.0.0/" + ), + "hosting_prepacked_artifact_version": "1.0.0", "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/terms/gemmaTerms.txt", "inference_vulnerable": False, "inference_dependencies": [], "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], + "training_dependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], "training_vulnerabilities": [], "deprecated": False, - "inference_environment_variables": [ + "hyperparameters": [ { - "name": 
"SAGEMAKER_PROGRAM", + "name": "peft_type", "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, + "default": "lora", + "options": ["lora", "None"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "name": "instruction_tuned", "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "chat_dataset", "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "ENDPOINT_SERVER_TIMEOUT", + "name": "epoch", "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", }, { - "name": "HF_MODEL_ID", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "name": "lora_r", + "type": "int", + "default": 64, + "min": 1, + "max": 1000, + "scope": "algorithm", }, + {"name": "lora_alpha", "type": "int", "default": 16, "min": 0, "scope": "algorithm"}, { - "name": "SM_NUM_GPUS", - "type": "text", - "default": "8", - "scope": "container", - "required_for_model_class": True, + "name": "lora_dropout", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", }, + {"name": "bits", "type": "int", "default": 4, "scope": "algorithm"}, { - "name": "MAX_INPUT_LENGTH", + "name": "double_quant", "type": "text", - "default": "1024", - "scope": "container", - "required_for_model_class": True, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "MAX_TOTAL_TOKENS", + "name": "quant_type", "type": "text", - "default": "2048", - "scope": "container", - "required_for_model_class": True, + "default": "nf4", + "options": ["fp4", "nf4"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "name": "per_device_train_batch_size", "type": "int", "default": 1, - "scope": "container", - "required_for_model_class": True, + "min": 1, + "max": 1000, + "scope": "algorithm", }, - ], - "metrics": [], - "default_inference_instance_type": "ml.p4de.24xlarge", - "supported_inference_instance_types": ["ml.p4de.24xlarge"], - "model_kwargs": {}, - "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 3600, - }, - "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", - "default_accept_type": "application/json", - }, - "inference_volume_size": 512, - "inference_enable_network_isolation": True, - "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "hf-llm-falcon-180b-bf16", - 
"hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - } + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 2, + "min": 1, + "max": 1000, + "scope": "algorithm", }, - "variants": { - "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "80"}}}, - "ml.p4d.24xlarge": { - "properties": { - "environment_variables": { - "YODEL": "NACEREMA", - } - } - }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.1, + "min": 0, + "max": 1, + "scope": "algorithm", }, - }, - }, - "inference-instance-types-variant-model": { - "model_id": "huggingface-llm-falcon-180b-bf16", - "url": "https://huggingface.co/tiiuae/falcon-180B", - "version": "1.0.0", - "min_sdk_version": "2.175.0", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "huggingface-llm", - "framework_version": "0.9.3", - "py_version": "py39", - "huggingface_transformers_version": "4.29.2", - }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" - "-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.1", - "hosting_use_script_uri": False, - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "inference_environment_variables": [ { - "name": "SAGEMAKER_PROGRAM", + "name": "train_from_scratch", "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "name": "fp16", "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "bf16", "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "evaluation_strategy", "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, + "default": "steps", + "options": ["steps", "epoch", "no"], + "scope": "algorithm", }, { - "name": "ENDPOINT_SERVER_TIMEOUT", + "name": 
"eval_steps", "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": 20, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "name": "gradient_accumulation_steps", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "HF_MODEL_ID", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "name": "logging_steps", + "type": "int", + "default": 8, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "SM_NUM_GPUS", - "type": "text", - "default": "8", - "scope": "container", - "required_for_model_class": True, + "name": "weight_decay", + "type": "float", + "default": 0.2, + "min": 1e-08, + "max": 1, + "scope": "algorithm", }, { - "name": "MAX_INPUT_LENGTH", + "name": "load_best_model_at_end", "type": "text", - "default": "1024", - "scope": "container", - "required_for_model_class": True, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "MAX_TOTAL_TOKENS", + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": 1024, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", "type": "text", - "default": "2048", - "scope": "container", - "required_for_model_class": True, + "default": "None", + "scope": "algorithm", }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "name": "gradient_checkpointing", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "early_stopping_patience", "type": "int", - "default": 1, + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_first_step", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "logging_nan_inf_filter", + "type": "text", + "default": "True", + "options": 
["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "steps", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "int", "default": 1, "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "deepspeed", + "type": "text", + "default": "False", + "options": ["False"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", "scope": "container", - "required_for_model_class": True, }, ], - "metrics": [], - "default_inference_instance_type": "ml.p4de.24xlarge", - "supported_inference_instance_types": ["ml.p4de.24xlarge"], - "default_training_instance_type": "ml.p4de.24xlarge", - "supported_training_instance_types": ["ml.p4de.24xlarge"], - "model_kwargs": {}, - "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 3600, - }, - "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", - "default_accept_type": "application/json", + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_key": ( + "source-directory-tarballs/huggingface/transfer_learning/llm/prepack/v1.1.1/sourcedir.tar.gz" + ), + "training_prepacked_script_version": "1.1.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", }, - "inference_volume_size": 512, - "inference_enable_network_isolation": True, - "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "hf-llm-falcon-180b-bf16", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "gpu_image_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/stud-gpu", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - } - }, - "variants": { - "ml.p2.12xlarge": { - "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "supported_inference_instance_types": ["ml.p5.xlarge"], - "default_inference_instance_type": "ml.p5.xlarge", - "metrics": [ - { 
- "Name": "huggingface-textgeneration:eval-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:instance-typemetric-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", - }, - ], - } - }, - "p2": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], - "default_inference_instance_type": "ml.p2.xlarge", - "metrics": [ - { - "Name": "huggingface-textgeneration:wtafigo", - "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:eval-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", - }, - ], - }, - }, - "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}}, - "p4": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/number2/" - }, - }, - "g4": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" - }, - }, - "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "g9": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "prepacked_artifact_key": "asfs/adsf/sda/f", - "hyperparameters": [ - { - "name": "num_bag_sets", - "type": "int", - "default": 5, - "min": 5, - "scope": "algorithm", - }, - { - "name": "num_stack_levels", - "type": "int", - "default": 6, - "min": 7, - "max": 3, - "scope": "algorithm", - }, - { - "name": "refit_full", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "set_best_to_refit_full", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "save_space", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "verbosity", - "type": "int", - "default": 2, - "min": 0, - "max": 4, - "scope": "algorithm", - }, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", - }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", - }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", - }, - ], - }, - }, - "p9": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"artifact_key": "do/re/mi"}, - }, - "m2": { - "regional_properties": {"image_uri": "$cpu_image_uri"}, - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "400"}}, - }, - "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "local": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "ml.g5.48xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} - }, - "ml.g5.12xlarge": { 
- "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} - }, - "g5": { - "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4", "JOHN": "DOE"} - } - }, - "ml.g9.12xlarge": { - "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "prepacked_artifact_key": "nlahdasf/asdf/asd/f", - "hyperparameters": [ - { - "name": "eval_metric", - "type": "text", - "default": "auto", - "scope": "algorithm", - }, - { - "name": "presets", - "type": "text", - "default": "medium_quality", - "options": [ - "best_quality", - "high_quality", - "good_quality", - "medium_quality", - "optimize_for_deployment", - "interpretable", - ], - "scope": "algorithm", - }, - { - "name": "auto_stack", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "num_bag_folds", - "type": "text", - "default": "0", - "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], - "scope": "algorithm", - }, - { - "name": "num_bag_sets", - "type": "int", - "default": 1, - "min": 1, - "scope": "algorithm", - }, - { - "name": "num_stack_levels", - "type": "int", - "default": 0, - "min": 0, - "max": 3, - "scope": "algorithm", - }, - ], - } - }, - "ml.p9.12xlarge": { - "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "artifact_key": "you/not/entertained", - } - }, - "g6": { - "properties": { - "environment_variables": {"BLAH": "4"}, - "artifact_key": "path/to/training/artifact.tar.gz", - "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", - } - }, - "trn1": { - "properties": { - "supported_inference_instance_types": ["ml.inf1.xlarge", "ml.inf1.2xlarge"], - "default_inference_instance_type": "ml.inf1.xlarge", - } - }, - }, - }, - "training_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", - }, - "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", - "training_prepacked_script_key": None, - "training_model_package_artifact_uris": None, - "deprecate_warn_message": None, - "deprecated_message": None, - "hosting_eula_key": None, - "hyperparameters": [ - { - "name": "epochs", - "type": "int", - "default": 3, - "min": 1, - "max": 1000, - "scope": "algorithm", - }, - { - "name": "adam-learning-rate", - "type": "float", - "default": 0.05, - "min": 1e-08, - "max": 1, - "scope": "algorithm", - }, - { - "name": "batch-size", - "type": "int", - "default": 4, - "min": 1, - "max": 1024, - "scope": "algorithm", + "training_artifact_key": "huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, }, { - "name": "sagemaker_submit_directory", + "name": "SAGEMAKER_SUBMIT_DIRECTORY", "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "default": "/opt/ml/model/code", "scope": "container", + "required_for_model_class": False, }, { - "name": "sagemaker_program", + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", "type": "text", - "default": "transfer_learning.py", + "default": "20", "scope": "container", + "required_for_model_class": False, }, { - "name": "sagemaker_container_log_level", + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", "type": "text", - "default": "20", + "default": "3600", + "scope": "container", 
+ "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "8192", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_BATCH_PREFILL_TOKENS", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, "scope": "container", + "required_for_model_class": True, }, ], - "training_vulnerable": False, - "deprecated": False, + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:train-loss", "Regex": "'loss': ([0-9]+\\.[0-9]+)"}, + ], + "default_inference_instance_type": "ml.g5.xlarge", + "supported_inference_instance_types": [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.2xlarge", + "supported_training_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, "estimator_kwargs": { "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, }, - "training_volume_size": 456, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "training_volume_size": 512, "inference_enable_network_isolation": True, - "training_enable_network_isolation": False, - }, - "variant-model": { - "model_id": "pytorch-ic-mobilenet-v2", - "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", - "version": "1.0.0", - "min_sdk_version": "2.49.0", - "training_supported": True, - "incremental_training_supported": True, - "hosting_model_package_arns": { - "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" - "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/oasst_top/train/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-llm-gemma-2b-instruct", + "default_payloads": { + "HelloWorld": { + "content_type": "application/json", + "prompt_key": "inputs", + 
"output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": ( + "user\nWrite a hello world program\nmodel" + ), + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, + }, + }, + }, + "MachineLearningPoem": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "Write me a poem about Machine Learning.", + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, + }, + }, + }, }, - "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": ( + "626614931356.dkr.ecr.af-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-east-1": { + "gpu_ecr_uri_1": ( + "871362719292.dkr.ecr.ap-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-south-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ca-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "cn-north-1": { + "gpu_ecr_uri_1": ( + "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-north-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-north-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-south-1": { + "gpu_ecr_uri_1": ( + "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-west-3": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-3.amazonaws.com/" + 
"huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "il-central-1": { + "gpu_ecr_uri_1": ( + "780543022126.dkr.ecr.il-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "me-south-1": { + "gpu_ecr_uri_1": ( + "217643126080.dkr.ecr.me-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "sa-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.sa-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-east-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g4dn.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, }, "training_instance_type_variants": { - "regional_aliases": {}, - "variants": { - "ml.p2.12xlarge": { - "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "hyperparameters": [ - { - "name": "eval_metric", - "type": "text", - "default": "auto", - "scope": "algorithm", - }, - { - "name": "presets", - "type": "text", - "default": "medium_quality", - "options": [ - "best_quality", - "high_quality", - "good_quality", - "medium_quality", - "optimize_for_deployment", - "interpretable", - ], - "scope": "algorithm", - }, - { - "name": "auto_stack", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "num_bag_folds", - "type": "text", - "default": "0", - "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], - "scope": "algorithm", - }, - { - "name": "num_bag_sets", - "type": "int", - "default": 1, - "min": 1, - "scope": "algorithm", - }, - { - "name": "batch-size", - "type": "int", - "default": 1, - "min": 1, - "scope": "algorithm", - }, - { - "name": "num_stack_levels", - "type": "int", - "default": 0, - "min": 0, - "max": 3, - "scope": 
"algorithm", - }, - ], - "metrics": [ - { - "Name": "huggingface-textgeneration:instance-typemetric-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:eval-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", - }, - ], - } + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": ( + "626614931356.dkr.ecr.af-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, - "p2": { - "regional_properties": {"image_uri": "$gpu_ecr_uri_2"}, - "properties": { - "hyperparameters": [ - { - "name": "num_bag_sets", - "type": "int", - "default": 5, - "min": 5, - "scope": "algorithm", - }, - { - "name": "num_stack_levels", - "type": "int", - "default": 6, - "min": 7, - "max": 3, - "scope": "algorithm", - }, - { - "name": "refit_full", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "set_best_to_refit_full", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "save_space", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "verbosity", - "type": "int", - "default": 2, - "min": 0, - "max": 4, - "scope": "algorithm", - }, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", - }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", - }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", - }, - ], - "metrics": [ - { - "Name": "huggingface-textgeneration:wtafigo", - "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:eval-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", - }, - ], - }, + "ap-east-1": { + "gpu_ecr_uri_1": ( + "871362719292.dkr.ecr.ap-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, - }, - }, - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - "inf_model_package_arn": "us-west-2/blah/blah/blah/inf", - "gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu", - } - }, - "variants": { - "p2": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } + "ap-northeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, - "p3": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } + "ap-northeast-2": { + "gpu_ecr_uri_1": ( + 
"763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, - "p4": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } + "ap-south-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ca-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "cn-north-1": { + "gpu_ecr_uri_1": ( + "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-north-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-north-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-south-1": { + "gpu_ecr_uri_1": ( + "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-west-3": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-3.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) }, + "il-central-1": { + "gpu_ecr_uri_1": ( + "780543022126.dkr.ecr.il-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "me-south-1": { + "gpu_ecr_uri_1": ( + "217643126080.dkr.ecr.me-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "sa-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.sa-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-east-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + 
"huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + }, + "variants": { "g4dn": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": ( + "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, }, "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, "properties": { - "resource_requirements": { - "num_accelerators": 888810, - "randon-field-2": 2222, - } - } + "gated_model_key_env_var_value": ( + "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, }, - "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "ml.g5.xlarge": { + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}, - "resource_requirements": {"num_accelerators": 10}, - } - }, - "ml.g5.48xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} + "gated_model_key_env_var_value": ( + "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, }, - "ml.g5.12xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} + "p4d": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": ( + "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, }, - "inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, - "inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, - "training_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", - }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "hosting_resource_requirements": {"min_memory_mb": 8192, "num_accelerators": 1}, "dynamic_container_deployment_supported": True, - "hosting_resource_requirements": { - "min_memory_mb": 81999, - "num_accelerators": 1, - "random_field_1": 1, + }, + # noqa: E501 + "env-var-variant-model": { + "model_id": "huggingface-llm-falcon-180b-bf16", + "url": "https://huggingface.co/tiiuae/falcon-180B", + "version": "1.6.2", + "min_sdk_version": "2.188.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.4.0", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", }, - "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", - "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", - "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", - "training_prepacked_script_key": None, - "hosting_prepacked_artifact_key": None, - 
"training_model_package_artifact_uris": None, - "deprecate_warn_message": None, - "deprecated_message": None, - "hosting_eula_key": None, - "hyperparameters": [ - { - "name": "epochs", - "type": "int", - "default": 3, - "min": 1, - "max": 1000, - "scope": "algorithm", - }, - { - "name": "adam-learning-rate", - "type": "float", - "default": 0.05, - "min": 1e-08, - "max": 1, - "scope": "algorithm", - }, - { - "name": "batch-size", - "type": "int", - "default": 4, - "min": 1, - "max": 1024, - "scope": "algorithm", - }, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", - }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", - }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", - }, - ], + "hosting_artifact_key": "huggingface-infer/v1.2.0/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.2.0/infer-prepack-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_prepacked_artifact_version": "1.2.0", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", @@ -2218,287 +2534,430 @@ "required_for_model_class": True, }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", "scope": "container", "required_for_model_class": True, }, - ], - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "default_inference_instance_type": "ml.p2.xlarge", - "supported_inference_instance_types": [ - "ml.p2.xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", - "ml.m5.large", - "ml.m5.xlarge", - "ml.c5.xlarge", - "ml.c5.2xlarge", - ], - "default_training_instance_type": "ml.p3.2xlarge", - "supported_training_instance_types": [ - "ml.p3.2xlarge", - "ml.p2.xlarge", - "ml.g4dn.2xlarge", - "ml.m5.xlarge", - "ml.c5.2xlarge", - ], - "hosting_use_script_uri": True, - "metrics": [ { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + "name": "SM_NUM_GPUS", + "type": "text", + "default": "8", + "scope": "container", + "required_for_model_class": True, }, { - "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", - "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, }, ], - "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, - "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, - "estimator_kwargs": { - "encrypt_inter_container_traffic": True, 
+ "metrics": [], + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge", "ml.p5.48xlarge"], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, "predictor_specs": { - "supported_content_types": ["application/x-image"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-image", + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", "default_accept_type": "application/json", }, - "inference_volume_size": 123, - "training_volume_size": 456, + "inference_volume_size": 512, "inference_enable_network_isolation": True, - "training_enable_network_isolation": False, - "resource_name_base": "dfsdfsds", - }, - "gated_llama_neuron_model": { - "model_id": "meta-textgenerationneuron-llama-2-7b", - "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", - "version": "1.0.0", - "min_sdk_version": "2.198.0", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "djl-neuronx", - "framework_version": "0.24.0", - "py_version": "py39", - }, - "hosting_artifact_key": "meta-textgenerationneuron/meta-textgenerationneuron-llama-2-7b/artifac" - "ts/inference/v1.0.0/", - "hosting_script_key": "source-directory-tarballs/meta/inference/textgenerationneuron/v1.0.0/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "meta-textgenerationneuron/meta-textgenerationneuro" - "n-llama-2-7b/artifacts/inference-prepack/v1.0.0/", - "hosting_prepacked_artifact_version": "1.0.0", - "hosting_use_script_uri": False, - "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", - "inference_vulnerable": False, - "inference_dependencies": [ - "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", - "sagemaker_jumpstart_script_utilities==1.1.8", - ], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [ - "sagemaker_jumpstart_huggingface_script_utilities==1.1.3", - "sagemaker_jumpstart_script_utilities==1.1.9", - "sagemaker_jumpstart_tabular_script_utilities==1.0.0", - ], - "training_vulnerabilities": [], - "deprecated": False, - "hyperparameters": [ - { - "name": "max_input_length", - "type": "int", - "default": 2048, - "min": 128, - "scope": "algorithm", + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "hf-llm-falcon-180b-bf16", + "default_payloads": { + "Girafatron": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. 
Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:", + "parameters": { + "max_new_tokens": 50, + "return_full_text": False, + "do_sample": True, + "top_k": 10, + "stop": ["Daniel:"], + "decoder_input_details": True, + "details": True, + }, + }, }, - { - "name": "preprocessing_num_workers", - "type": "text", - "default": "None", - "scope": "algorithm", + "Factorial": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "Write a program to compute factorial in python:", + "parameters": { + "max_new_tokens": 200, + "decoder_input_details": True, + "details": True, + }, + }, }, - { - "name": "learning_rate", - "type": "float", - "default": 6e-06, - "min": 1e-08, - "max": 1, - "scope": "algorithm", + "Website": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "Building a website can be done in 10 simple steps:", + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, + }, + }, }, - { - "name": "min_learning_rate", - "type": "float", - "default": 1e-06, - "min": 1e-12, - "max": 1, - "scope": "algorithm", - }, - {"name": "max_steps", "type": "int", "default": 20, "min": 2, "scope": "algorithm"}, - { - "name": "global_train_batch_size", - "type": "int", - "default": 256, - "min": 1, - "scope": "algorithm", + "TranslateEnglishToFrench": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "Translate English to French:\n\nsea otter => loutre de mer\n\npeppermint => menthe poivr\u00e9e\n\nplush girafe => girafe peluche\n\ncheese =>", + "parameters": { + "max_new_tokens": 3, + "decoder_input_details": True, + "details": True, + }, + }, }, - { - "name": "per_device_train_batch_size", - "type": "int", - "default": 1, - "min": 1, - "scope": "algorithm", + "SentimentAnalysis": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": '"I hate it when my phone battery dies."\nSentiment: Negative\n###\nTweet: "My day has been :+1:"\nSentiment: Positive\n###\nTweet: "This is the link to the article"\nSentiment: Neutral\n###\nTweet: "This new music video was incredibile"\nSentiment:', + "parameters": { + "max_new_tokens": 2, + "decoder_input_details": True, + "details": True, + }, + }, }, - { - "name": "layer_norm_epilson", - "type": "float", - "default": 1e-05, - "min": 1e-12, - "scope": "algorithm", + "QuestionAnswering": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "Could you remind me when was the C programming language invented?", + "parameters": { + "max_new_tokens": 50, + "decoder_input_details": True, + "details": True, + }, + }, }, - { - "name": "weight_decay", - "type": "float", - "default": 0.1, - "min": 1e-08, - "max": 1, - "scope": "algorithm", + "RecipeGeneration": { + "content_type": 
"application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "What is the recipe for a delicious lemon cheesecake?", + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, + }, + }, }, - { - "name": "lr_scheduler_type", - "type": "text", - "default": "CosineAnnealing", - "options": ["CosineAnnealing"], - "scope": "algorithm", + "Summarization": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "Starting today, the state-of-the-art Falcon 40B foundation model from Technology\nInnovation Institute (TII) is available on Amazon SageMaker JumpStart, SageMaker's machine learning (ML) hub\nthat offers pre-trained models, built-in algorithms, and pre-built solution templates to help you quickly get\nstarted with ML. You can deploy and use this Falcon LLM with a few clicks in SageMaker Studio or\nprogrammatically through the SageMaker Python SDK.\nFalcon 40B is a 40-billion-parameter large language model (LLM) available under the Apache 2.0 license that\nranked #1 in Hugging Face Open LLM leaderboard, which tracks, ranks, and evaluates LLMs across multiple\nbenchmarks to identify top performing models. Since its release in May 2023, Falcon 40B has demonstrated\nexceptional performance without specialized fine-tuning. To make it easier for customers to access this\nstate-of-the-art model, AWS has made Falcon 40B available to customers via Amazon SageMaker JumpStart.\nNow customers can quickly and easily deploy their own Falcon 40B model and customize it to fit their specific\nneeds for applications such as translation, question answering, and summarizing information.\nFalcon 40B are generally available today through Amazon SageMaker JumpStart in US East (Ohio),\nUS East (N. Virginia), US West (Oregon), Asia Pacific (Tokyo), Asia Pacific (Seoul), Asia Pacific (Mumbai),\nEurope (London), Europe (Frankfurt), Europe (Ireland), and Canada (Central),\nwith availability in additional AWS Regions coming soon. To learn how to use this new feature,\nplease see SageMaker JumpStart documentation, the Introduction to SageMaker JumpStart \u2013\nText Generation with Falcon LLMs example notebook, and the blog Technology Innovation Institute trains\nthe state-of-the-art Falcon LLM 40B foundation model on Amazon SageMaker. 
Summarize the article above:", + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, + }, + }, }, - {"name": "warmup_steps", "type": "int", "default": 10, "min": 0, "scope": "algorithm"}, - {"name": "constant_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, - { - "name": "adam_beta1", - "type": "float", - "default": 0.9, - "min": 1e-08, - "max": 1, - "scope": "algorithm", + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-3": { + "gpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-3": { + "gpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "cn-northwest-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-central-2": { + "gpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "il-central-1": { + "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "me-central-1": { + "gpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-gov-east-1": { + "gpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-gov-west-1": { + "gpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, }, - { - "name": "adam_beta2", - "type": "float", - "default": 0.95, - "min": 1e-08, - "max": 1, - "scope": "algorithm", + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g4dn.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4de.24xlarge": { + "properties": { + "environment_variables": {"SM_NUM_GPUS": "8"}, + "resource_requirements": {"min_memory_mb": 589824, "num_accelerators": 8}, + } + }, + "ml.p5.48xlarge": { + "properties": { + "resource_requirements": {"min_memory_mb": 1048576, 
"num_accelerators": 8} + } + }, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "80"}}}, + "ml.p4d.24xlarge": { + "properties": { + "environment_variables": { + "YODEL": "NACEREMA", + } + } + }, }, + }, + "hosting_resource_requirements": {"min_memory_mb": 589824, "num_accelerators": 8}, + "dynamic_container_deployment_supported": True, + "bedrock_console_supported": True, + "bedrock_io_mapping_id": "tgi_default_1.0.0", + }, + "inference-instance-types-variant-model": { + "model_id": "huggingface-llm-falcon-180b-bf16", + "url": "https://huggingface.co/tiiuae/falcon-180B", + "version": "1.0.0", + "min_sdk_version": "2.175.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "0.9.3", + "py_version": "py39", + "huggingface_transformers_version": "4.29.2", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" + "-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.1", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ { - "name": "mixed_precision", + "name": "SAGEMAKER_PROGRAM", "type": "text", - "default": "True", - "options": ["True", "False"], - "scope": "algorithm", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, }, { - "name": "tensor_parallel_degree", + "name": "SAGEMAKER_SUBMIT_DIRECTORY", "type": "text", - "default": "8", - "options": ["8"], - "scope": "algorithm", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, }, { - "name": "pipeline_parallel_degree", + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", "type": "text", - "default": "1", - "options": ["1"], - "scope": "algorithm", + "default": "20", + "scope": "container", + "required_for_model_class": False, }, { - "name": "append_eod", + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", + "default": "3600", + "scope": "container", + "required_for_model_class": False, }, { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", - }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", - }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, "scope": "container", + "required_for_model_class": True, }, - ], - "training_script_key": "source-directory-tarballs/meta/transfer_learning/textgenerati" - "onneuron/v1.0.0/sourcedir.tar.gz", - "training_prepacked_script_key": "source-directory-tarballs/meta/tra" - "nsfer_learning/textgenerationneuron/prepack/v1.0.0/sourcedir.tar.gz", - "training_prepacked_script_version": "1.0.0", - "training_ecr_specs": { - "framework": "huggingface", - "framework_version": "2.0.0", - "py_version": "py310", - "huggingface_transformers_version": "4.28.1", - }, - 
"training_artifact_key": "meta-training/train-meta-textgenerationneuron-llama-2-7b.tar.gz", - "inference_environment_variables": [ { - "name": "SAGEMAKER_PROGRAM", + "name": "MODEL_CACHE_ROOT", "type": "text", - "default": "inference.py", + "default": "/opt/ml/model", "scope": "container", "required_for_model_class": True, }, { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "name": "SAGEMAKER_ENV", "type": "text", - "default": "/opt/ml/model/code", + "default": "1", "scope": "container", - "required_for_model_class": False, + "required_for_model_class": True, }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "HF_MODEL_ID", "type": "text", - "default": "20", + "default": "/opt/ml/model", "scope": "container", - "required_for_model_class": False, + "required_for_model_class": True, }, { - "name": "MODEL_CACHE_ROOT", + "name": "SM_NUM_GPUS", "type": "text", - "default": "/opt/ml/model", + "default": "8", "scope": "container", - "required_for_model_class": False, + "required_for_model_class": True, }, { - "name": "SAGEMAKER_ENV", + "name": "MAX_INPUT_LENGTH", "type": "text", - "default": "1", + "default": "1024", "scope": "container", - "required_for_model_class": False, + "required_for_model_class": True, }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "MAX_TOTAL_TOKENS", "type": "text", - "default": "3600", + "default": "2048", "scope": "container", - "required_for_model_class": False, + "required_for_model_class": True, }, { "name": "SAGEMAKER_MODEL_SERVER_WORKERS", @@ -2508,440 +2967,279 @@ "required_for_model_class": True, }, ], - "metrics": [ - { - "Name": "meta-textgenerationneuron:train-loss", - "Regex": "reduced_train_loss=([0-9]+\\.[0-9]+)", - } - ], - "default_inference_instance_type": "ml.inf2.xlarge", - "supported_inference_instance_types": [ - "ml.inf2.xlarge", - "ml.inf2.8xlarge", - "ml.inf2.24xlarge", - "ml.inf2.48xlarge", - ], - "default_training_instance_type": "ml.trn1.32xlarge", - "supported_training_instance_types": ["ml.trn1.32xlarge", "ml.trn1n.32xlarge"], + "metrics": [], + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "default_training_instance_type": "ml.p4de.24xlarge", + "supported_training_instance_types": ["ml.p4de.24xlarge"], "model_kwargs": {}, "deploy_kwargs": { "model_data_download_timeout": 3600, "container_startup_health_check_timeout": 3600, }, - "estimator_kwargs": { - "encrypt_inter_container_traffic": True, - "disable_output_compression": True, - "max_run": 360000, - }, - "fit_kwargs": {}, "predictor_specs": { "supported_content_types": ["application/json"], "supported_accept_types": ["application/json"], "default_content_type": "application/json", "default_accept_type": "application/json", }, - "inference_volume_size": 256, - "training_volume_size": 256, + "inference_volume_size": 512, "inference_enable_network_isolation": True, - "training_enable_network_isolation": True, - "default_training_dataset_key": "training-datasets/sec_amazon/", "validation_supported": False, - "fine_tuning_supported": True, - "resource_name_base": "meta-textgenerationneuron-llama-2-7b", - "default_payloads": { - "meaningOfLife": { - "content_type": "application/json", - "prompt_key": "inputs", - "output_keys": {"generated_text": "generated_text"}, - "body": { - "inputs": "I believe the meaning of life is", - "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, - }, - }, - "theoryOfRelativity": { - "content_type": "application/json", - "prompt_key": "inputs", - 
"output_keys": {"generated_text": "generated_text"}, - "body": { - "inputs": "Simply put, the theory of relativity states that ", - "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, - }, - }, - "teamMessage": { - "content_type": "application/json", - "prompt_key": "inputs", - "output_keys": {"generated_text": "generated_text"}, - "body": { - "inputs": "A brief message congratulating the team on the launch:\n\nHi " - "everyone,\n\nI just ", - "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, - }, - }, - "englishToFrench": { - "content_type": "application/json", - "prompt_key": "inputs", - "output_keys": {"generated_text": "generated_text"}, - "body": { - "inputs": "Translate English to French:\nsea otter => loutre de mer\npep" - "permint => menthe poivr\u00e9e\nplush girafe => girafe peluche\ncheese =>", - "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, - }, - }, - }, - "gated_bucket": True, - "hosting_instance_type_variants": { + "fine_tuning_supported": False, + "resource_name_base": "hf-llm-falcon-180b-bf16", + "training_instance_type_variants": { "regional_aliases": { - "af-south-1": { - "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/djl-in" - "ference:0.24.0-neuronx-sdk2.14.1" - }, - "ap-east-1": { - "alias_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/djl-in" - "ference:0.24.0-neuronx-sdk2.14.1" - }, - "ap-northeast-1": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/d" - "jl-inference:0.24.0-neuronx-sdk2.14.1" - }, - "ap-northeast-2": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com" - "/djl-inference:0.24.0-neuronx-sdk2.14.1" - }, - "ap-south-1": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" - "djl-inference:0.24.0-neuronx-sdk2.14.1" - }, - "ap-southeast-1": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com" - "/djl-inference:0.24.0-neuronx-sdk2.14.1" - }, - "ap-southeast-2": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com" - "/djl-inference:0.24.0-neuronx-sdk2.14.1" - }, - "ca-central-1": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" - "djl-inference:0.24.0-neuronx-sdk2.14.1" - }, - "cn-north-1": { - "alias_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" - "djl-inference:0.24.0-neuronx-sdk2.14.1" + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "gpu_image_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/stud-gpu", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + } + }, + "variants": { + "ml.p2.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "supported_inference_instance_types": ["ml.p5.xlarge"], + "default_inference_instance_type": "ml.p5.xlarge", + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", + }, + ], + } }, - "eu-central-1": { - "alias_ecr_uri_1": 
"763104351884.dkr.ecr.eu-central-1.amazonaws.com/" - "djl-inference:0.24.0-neuronx-sdk2.14.1" + "p2": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], + "default_inference_instance_type": "ml.p2.xlarge", + "metrics": [ + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", + }, + ], + }, }, - "eu-north-1": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/" - "djl-inference:0.24.0-neuronx-sdk2.14.1" + "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}}, + "p4": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/number2/" + }, }, - "eu-south-1": { - "alias_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" - "djl-inference:0.24.0-neuronx-sdk2.14.1" + "g4": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "training_artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + }, }, - "eu-west-1": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/d" - "jl-inference:0.24.0-neuronx-sdk2.14.1" - }, - "eu-west-2": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/d" - "jl-inference:0.24.0-neuronx-sdk2.14.1" - }, - "eu-west-3": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/d" - "jl-inference:0.24.0-neuronx-sdk2.14.1" - }, - "me-south-1": { - "alias_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com" - "/djl-inference:0.24.0-neuronx-sdk2.14.1" + "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "g9": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "prepacked_artifact_key": "asfs/adsf/sda/f", + "hyperparameters": [ + { + "name": "num_bag_sets", + "type": "int", + "default": 5, + "min": 5, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 6, + "min": 7, + "max": 3, + "scope": "algorithm", + }, + { + "name": "refit_full", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "set_best_to_refit_full", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_space", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "verbosity", + "type": "int", + "default": 2, + "min": 0, + "max": 4, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + }, }, - "sa-east-1": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com" - 
"/djl-inference:0.24.0-neuronx-sdk2.14.1" + "p9": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": {"training_artifact_key": "do/re/mi"}, }, - "us-east-1": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/" - "djl-inference:0.24.0-neuronx-sdk2.14.1" + "m2": { + "regional_properties": {"image_uri": "$cpu_image_uri"}, + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "400"}}, }, - "us-east-2": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com" - "/djl-inference:0.24.0-neuronx-sdk2.14.1" + "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "local": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.48xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} }, - "us-west-1": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.co" - "m/djl-inference:0.24.0-neuronx-sdk2.14.1" + "ml.g5.12xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} }, - "us-west-2": { - "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-" - "inference:0.24.0-neuronx-sdk2.14.1" + "g5": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4", "JOHN": "DOE"} + } }, - }, - "variants": { - "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "ml.inf2.xlarge": { + "ml.g9.12xlarge": { "properties": { - "environment_variables": { - "OPTION_TENSOR_PARALLEL_DEGREE": "2", - "OPTION_N_POSITIONS": "1024", - "OPTION_DTYPE": "fp16", - "OPTION_ROLLING_BATCH": "auto", - "OPTION_MAX_ROLLING_BATCH_SIZE": "1", - "OPTION_NEURON_OPTIMIZE_LEVEL": "2", - } + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "prepacked_artifact_key": "nlahdasf/asdf/asd/f", + "hyperparameters": [ + { + "name": "eval_metric", + "type": "text", + "default": "auto", + "scope": "algorithm", + }, + { + "name": "presets", + "type": "text", + "default": 
"medium_quality", + "options": [ + "best_quality", + "high_quality", + "good_quality", + "medium_quality", + "optimize_for_deployment", + "interpretable", + ], + "scope": "algorithm", + }, + { + "name": "auto_stack", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "num_bag_folds", + "type": "text", + "default": "0", + "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "scope": "algorithm", + }, + { + "name": "num_bag_sets", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 0, + "min": 0, + "max": 3, + "scope": "algorithm", + }, + ], } }, - "ml.inf2.8xlarge": { + "ml.p9.12xlarge": { "properties": { - "environment_variables": { - "OPTION_TENSOR_PARALLEL_DEGREE": "2", - "OPTION_N_POSITIONS": "2048", - "OPTION_DTYPE": "fp16", - "OPTION_ROLLING_BATCH": "auto", - "OPTION_MAX_ROLLING_BATCH_SIZE": "4", - "OPTION_NEURON_OPTIMIZE_LEVEL": "2", - } + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "training_artifact_key": "you/not/entertained", } }, - "ml.inf2.24xlarge": { + "g6": { "properties": { - "environment_variables": { - "OPTION_TENSOR_PARALLEL_DEGREE": "12", - "OPTION_N_POSITIONS": "4096", - "OPTION_DTYPE": "fp16", - "OPTION_ROLLING_BATCH": "auto", - "OPTION_MAX_ROLLING_BATCH_SIZE": "4", - "OPTION_NEURON_OPTIMIZE_LEVEL": "2", - } + "environment_variables": {"BLAH": "4"}, + "training_artifact_key": "path/to/training/artifact.tar.gz", + "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", } }, - "ml.inf2.48xlarge": { + "trn1": { "properties": { - "environment_variables": { - "OPTION_TENSOR_PARALLEL_DEGREE": "24", - "OPTION_N_POSITIONS": "4096", - "OPTION_DTYPE": "fp16", - "OPTION_ROLLING_BATCH": "auto", - "OPTION_MAX_ROLLING_BATCH_SIZE": "4", - "OPTION_NEURON_OPTIMIZE_LEVEL": "2", - } + "supported_inference_instance_types": ["ml.inf1.xlarge", "ml.inf1.2xlarge"], + "default_inference_instance_type": "ml.inf1.xlarge", } }, }, }, - "training_instance_type_variants": { - "regional_aliases": { - "af-south-1": { - "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorc" - "h-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "ap-east-1": { - "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch" - "-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "ap-northeast-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-" - "pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "ap-northeast-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingfa" - "ce-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "ap-south-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch" - "-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "ap-southeast-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" - "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "ap-southeast-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" - "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "ca-central-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/hu" - 
"ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "cn-north-1": { - "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" - "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "eu-central-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/hug" - "gingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "eu-north-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/hu" - "ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "eu-south-1": { - "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/hu" - "ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "eu-west-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/hug" - "gingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "eu-west-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/hug" - "gingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "eu-west-3": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggi" - "ngface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "me-south-1": { - "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggin" - "gface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "sa-east-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/hugg" - "ingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "us-east-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/hu" - "ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", - "neuron_ecr_uri": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-" - "training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", - }, - "us-east-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-py" - "torch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", - "neuron_ecr_uri": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-trai" - "ning-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", - }, - "us-west-1": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytor" - "ch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - "us-west-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface" - "-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", - "neuron_ecr_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch" - "-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", - }, - }, - "variants": { - "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "trn1": { - "regional_properties": {"image_uri": "$neuron_ecr_uri"}, - "properties": { - 
"gated_model_key_env_var_value": "meta-training/trn1/v1.0." - "0/train-meta-textgenerationneuron-llama-2-7b.tar.gz" - }, - }, - "trn1n": { - "regional_properties": {"image_uri": "$neuron_ecr_uri"}, - "properties": { - "gated_model_key_env_var_value": "meta-training/trn1n/v1.0.0" - "/train-meta-textgenerationneuron-llama-2-7b.tar.gz" - }, - }, - }, - }, - "hosting_artifact_s3_data_type": "S3Prefix", - "hosting_artifact_compression_type": "None", - "hosting_resource_requirements": {"min_memory_mb": 8192, "num_accelerators": 1}, - "dynamic_container_deployment_supported": True, - }, - "gated_variant-model": { - "model_id": "pytorch-ic-mobilenet-v2", - "gated_bucket": True, - "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", - "version": "1.0.0", - "min_sdk_version": "2.49.0", - "training_supported": True, - "incremental_training_supported": True, - "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", - }, - "training_instance_type_variants": None, - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - } - }, - "variants": { - "p2": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - }, - "properties": { - "prepacked_artifact_key": "some-instance-specific/model/prefix/" - }, - }, - "p3": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - } - }, - "p4": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - } - }, - "g4dn": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - } - }, - "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "ml.g5.48xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} - }, - "ml.g5.12xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} - }, - }, - }, "training_ecr_specs": { "framework": "pytorch", "framework_version": "1.5.0", "py_version": "py3", }, - "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", "training_prepacked_script_key": None, - "hosting_prepacked_artifact_key": None, "training_model_package_artifact_uris": None, "deprecate_warn_message": None, "deprecated_message": None, @@ -2990,152 +3288,264 @@ "scope": "container", }, ], - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "ENDPOINT_SERVER_TIMEOUT", - "type": "int", - "default": 3600, - 
"scope": "container", - "required_for_model_class": True, - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, - "scope": "container", - "required_for_model_class": True, - }, - ], - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], "deprecated": False, - "default_inference_instance_type": "ml.p2.xlarge", - "supported_inference_instance_types": [ - "ml.p2.xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", - "ml.m5.large", - "ml.m5.xlarge", - "ml.c5.xlarge", - "ml.c5.2xlarge", - ], - "default_training_instance_type": "ml.p3.2xlarge", - "supported_training_instance_types": [ - "ml.p3.2xlarge", - "ml.p2.xlarge", - "ml.g4dn.2xlarge", - "ml.m5.xlarge", - "ml.c5.2xlarge", - ], - "hosting_use_script_uri": False, - "metrics": [ - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'loss default': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", - "Regex": "'loss default': ([0-9]+\\.[0-9]+)", - }, - ], "estimator_kwargs": { "encrypt_inter_container_traffic": True, }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, - "predictor_specs": { - "supported_content_types": ["application/x-image"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-image", - "default_accept_type": "application/json", - }, - "inference_volume_size": 123, "training_volume_size": 456, "inference_enable_network_isolation": True, "training_enable_network_isolation": False, - "resource_name_base": "dfsdfsds", }, - "model-artifact-variant-model": { + # noqa: E501 + "variant-model": { "model_id": "pytorch-ic-mobilenet-v2", "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", "version": "1.0.0", "min_sdk_version": "2.49.0", "training_supported": True, "incremental_training_supported": True, + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" + "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + }, "hosting_ecr_specs": { "framework": "pytorch", "framework_version": "1.5.0", "py_version": "py3", }, + "training_instance_type_variants": { + "regional_aliases": {}, + "variants": { + "ml.p2.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "hyperparameters": [ + { + "name": "eval_metric", + "type": "text", + "default": "auto", + "scope": "algorithm", + }, + { + "name": "presets", + "type": "text", + "default": "medium_quality", + "options": [ + "best_quality", + "high_quality", + "good_quality", + "medium_quality", + "optimize_for_deployment", + "interpretable", + ], + "scope": "algorithm", + }, + { + "name": "auto_stack", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "num_bag_folds", + "type": "text", + "default": "0", + "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "scope": "algorithm", + }, + { + "name": "num_bag_sets", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "batch-size", + "type": "int", + "default": 1, + 
"min": 1, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 0, + "min": 0, + "max": 3, + "scope": "algorithm", + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", + }, + ], + } + }, + "p2": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_2"}, + "properties": { + "hyperparameters": [ + { + "name": "num_bag_sets", + "type": "int", + "default": 5, + "min": 5, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 6, + "min": 7, + "max": 3, + "scope": "algorithm", + }, + { + "name": "refit_full", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "set_best_to_refit_full", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_space", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "verbosity", + "type": "int", + "default": 2, + "min": 0, + "max": 4, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", + }, + ], + }, + }, + }, + }, "hosting_instance_type_variants": { "regional_aliases": { "us-west-2": { "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + "inf_model_package_arn": "us-west-2/blah/blah/blah/inf", + "gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu", } }, "variants": { "p2": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"prepacked_artifact_key": "hello-world-1"}, + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "p3": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "p4": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "g4dn": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "g5": { + "properties": { + "resource_requirements": { + "num_accelerators": 888810, + 
"randon-field-2": 2222, + } + } }, - "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "p4": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}, + "resource_requirements": {"num_accelerators": 10}, + } + }, "ml.g5.48xlarge": { "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} }, "ml.g5.12xlarge": { "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} }, + "inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, + "inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, }, }, "training_ecr_specs": { @@ -3143,42 +3553,21 @@ "framework_version": "1.5.0", "py_version": "py3", }, - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - } - }, - "variants": { - "p2": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"artifact_key": "hello-mars-1"}, - }, - "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "p4": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "ml.g5.48xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} - }, - "ml.g5.12xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} - }, - }, + "dynamic_container_deployment_supported": True, + "hosting_resource_requirements": { + "min_memory_mb": 81999, + "num_accelerators": 1, + "random_field_1": 1, }, "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_key": None, + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", "training_prepacked_script_key": None, - "hosting_prepacked_artifact_key": "basfsdfssf", + "hosting_prepacked_artifact_key": None, "training_model_package_artifact_uris": None, "deprecate_warn_message": None, "deprecated_message": None, - "hosting_model_package_arns": None, "hosting_eula_key": None, "hyperparameters": [ { @@ -3308,8 +3697,17 @@ "ml.c5.2xlarge", ], "hosting_use_script_uri": True, - "metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], - "model_kwargs": {}, + "metrics": [ + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + ], + "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, "estimator_kwargs": { "encrypt_inter_container_traffic": True, @@ -3325,110 +3723,153 @@ 
"training_volume_size": 456, "inference_enable_network_isolation": True, "training_enable_network_isolation": False, + "resource_name_base": "dfsdfsds", }, - "private-model": { - "model_id": "pytorch-ic-mobilenet-v2", - "gated_bucket": True, - "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", + "gated_llama_neuron_model": { + "model_id": "meta-textgenerationneuron-llama-2-7b", + "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", "version": "1.0.0", - "min_sdk_version": "2.49.0", + "min_sdk_version": "2.198.0", "training_supported": True, - "incremental_training_supported": True, - "hosting_model_package_arns": { - "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" - "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" - }, + "incremental_training_supported": False, "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", - }, - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - "inf_model_package_arn": "us-west-2/blah/blah/blah/inf", - "gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu", - } - }, - "variants": { - "p2": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } - }, - "p3": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } - }, - "p4": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } - }, - "g4dn": { - "regional_properties": { - "image_uri": "$gpu_image_uri", - "model_package_arn": "$gpu_model_package_arn", - } - }, - "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, - "ml.g5.48xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} - }, - "ml.g5.12xlarge": { - "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} - }, - "inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, - "inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, - }, - }, - "training_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "framework": "djl-neuronx", + "framework_version": "0.24.0", + "py_version": "py39", }, - "training_instance_type_variants": None, - "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", - "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", - "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", - "training_prepacked_script_key": None, - "hosting_prepacked_artifact_key": None, - "training_model_package_artifact_uris": None, - "deprecate_warn_message": None, - "deprecated_message": None, - "hosting_eula_key": None, + "hosting_artifact_key": "meta-textgenerationneuron/meta-textgenerationneuron-llama-2-7b/artifac" + "ts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/meta/inference/textgenerationneuron/v1.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": 
"meta-textgenerationneuron/meta-textgenerationneuro" + "n-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", + "inference_vulnerable": False, + "inference_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", + "sagemaker_jumpstart_script_utilities==1.1.8", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.1.3", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + ], + "training_vulnerabilities": [], + "deprecated": False, "hyperparameters": [ { - "name": "epochs", + "name": "max_input_length", "type": "int", - "default": 3, - "min": 1, - "max": 1000, + "default": 2048, + "min": 128, "scope": "algorithm", }, { - "name": "adam-learning-rate", + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "learning_rate", "type": "float", - "default": 0.05, + "default": 6e-06, "min": 1e-08, "max": 1, "scope": "algorithm", }, { - "name": "batch-size", + "name": "min_learning_rate", + "type": "float", + "default": 1e-06, + "min": 1e-12, + "max": 1, + "scope": "algorithm", + }, + {"name": "max_steps", "type": "int", "default": 20, "min": 2, "scope": "algorithm"}, + { + "name": "global_train_batch_size", "type": "int", - "default": 4, + "default": 256, "min": 1, - "max": 1024, + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "layer_norm_epilson", + "type": "float", + "default": 1e-05, + "min": 1e-12, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.1, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "CosineAnnealing", + "options": ["CosineAnnealing"], + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 10, "min": 0, "scope": "algorithm"}, + {"name": "constant_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.95, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "mixed_precision", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "tensor_parallel_degree", + "type": "text", + "default": "8", + "options": ["8"], + "scope": "algorithm", + }, + { + "name": "pipeline_parallel_degree", + "type": "text", + "default": "1", + "options": ["1"], + "scope": "algorithm", + }, + { + "name": "append_eod", + "type": "text", + "default": "False", + "options": ["True", "False"], "scope": "algorithm", }, { @@ -3450,6 +3891,18 @@ "scope": "container", }, ], + "training_script_key": "source-directory-tarballs/meta/transfer_learning/textgenerati" + "onneuron/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/meta/tra" + "nsfer_learning/textgenerationneuron/prepack/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_version": "1.0.0", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + 
"huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "meta-training/train-meta-textgenerationneuron-llama-2-7b.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", @@ -3473,32 +3926,25 @@ "required_for_model_class": False, }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "MODEL_CACHE_ROOT", "type": "text", - "default": "3600", + "default": "/opt/ml/model", "scope": "container", "required_for_model_class": False, }, { - "name": "ENDPOINT_SERVER_TIMEOUT", - "type": "int", - "default": 3600, + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", "scope": "container", - "required_for_model_class": True, - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "required_for_model_class": False, }, { - "name": "SAGEMAKER_ENV", + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", "type": "text", - "default": "1", + "default": "3600", "scope": "container", - "required_for_model_class": True, + "required_for_model_class": False, }, { "name": "SAGEMAKER_MODEL_SERVER_WORKERS", @@ -3508,96 +3954,32 @@ "required_for_model_class": True, }, ], - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "default_inference_instance_type": "ml.p2.xlarge", - "supported_inference_instance_types": [ - "ml.p2.xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", - "ml.m5.large", - "ml.m5.xlarge", - "ml.c5.xlarge", - "ml.c5.2xlarge", - ], - "default_training_instance_type": "ml.p3.2xlarge", - "supported_training_instance_types": [ - "ml.p3.2xlarge", - "ml.p2.xlarge", - "ml.g4dn.2xlarge", - "ml.m5.xlarge", - "ml.c5.2xlarge", + "metrics": [ + { + "Name": "meta-textgenerationneuron:train-loss", + "Regex": "reduced_train_loss=([0-9]+\\.[0-9]+)", + } ], - "hosting_use_script_uri": True, - "metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], - "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, - "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, - "estimator_kwargs": { - "encrypt_inter_container_traffic": True, - }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, - "predictor_specs": { - "supported_content_types": ["application/x-image"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-image", - "default_accept_type": "application/json", - }, - "inference_volume_size": 123, - "training_volume_size": 456, - "inference_enable_network_isolation": True, - "training_enable_network_isolation": False, - "resource_name_base": "dfsdfsds", - }, - "js-model-package-arn": { - "model_id": "meta-textgeneration-llama-2-7b-f", - "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", - "version": "1.0.0", - "min_sdk_version": "2.173.0", - "training_supported": False, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.12.0", - "py_version": "py38", - }, - "hosting_artifact_key": "meta-infer/infer-meta-textgeneration-llama-2-7b-f.tar.gz", - "hosting_script_key": "source-directory-tarballs/meta/inference/textgeneration/v1.0.0/sourcedir.tar.gz", - "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", - "hosting_model_package_arns": { - "us-west-2": 
"arn:aws:sagemaker:us-west-2:594846645681:model-package/" - "llama2-7b-f-e46eb8a833643ed58aaccd81498972c3", - "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/" - "llama2-7b-f-e46eb8a833643ed58aaccd81498972c3", - }, - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "inference_environment_variables": [], - "metrics": [], - "default_inference_instance_type": "ml.g5.2xlarge", + "default_inference_instance_type": "ml.inf2.xlarge", "supported_inference_instance_types": [ - "ml.g5.2xlarge", - "ml.g5.4xlarge", - "ml.g5.8xlarge", - "ml.g5.12xlarge", - "ml.g5.24xlarge", - "ml.g5.48xlarge", - "ml.p4d.24xlarge", + "ml.inf2.xlarge", + "ml.inf2.8xlarge", + "ml.inf2.24xlarge", + "ml.inf2.48xlarge", ], + "default_training_instance_type": "ml.trn1.32xlarge", + "supported_training_instance_types": ["ml.trn1.32xlarge", "ml.trn1n.32xlarge"], "model_kwargs": {}, "deploy_kwargs": { "model_data_download_timeout": 3600, "container_startup_health_check_timeout": 3600, }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, + }, + "fit_kwargs": {}, "predictor_specs": { "supported_content_types": ["application/json"], "supported_accept_types": ["application/json"], @@ -3605,261 +3987,455 @@ "default_accept_type": "application/json", }, "inference_volume_size": 256, + "training_volume_size": 256, "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/sec_amazon/", "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "meta-textgeneration-llama-2-7b-f", - }, - "js-trainable-model-prepacked": { - "model_id": "huggingface-text2text-flan-t5-base", - "url": "https://huggingface.co/google/flan-t5-base", - "version": "1.2.0", - "min_sdk_version": "2.130.0", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "huggingface", - "framework_version": "1.10.2", - "py_version": "py38", - "huggingface_transformers_version": "4.17.0", - }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-base.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.4/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-" - "huggingface-text2text-flan-t5-base.tar.gz", - "hosting_prepacked_artifact_version": "1.0.0", - "inference_vulnerable": False, - "inference_dependencies": [ - "accelerate==0.16.0", - "bitsandbytes==0.37.0", - "filelock==3.9.0", - "huggingface_hub==0.12.0", - "regex==2022.7.9", - "tokenizers==0.13.2", - "transformers==4.26.0", - ], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [ - "Brotli==1.0.9", - "absl-py==1.4.0", - "accelerate==0.16.0", - "datasets==2.9.0", - "deepspeed==0.8.0", - "evaluate==0.4.0", - "hjson==3.1.0", - "huggingface_hub==0.13.3", - "inflate64==0.3.1", - "multivolumefile==0.2.3", - "ninja==1.11.1", - "nltk==3.8.1", - "psutil==5.9.4", - "py-cpuinfo==9.0.0", - "py7zr==0.20.4", - "pybcj==1.0.1", - "pycryptodomex==3.17", - "pydantic==1.10.2", - "pyppmd==1.0.0", - "pyzstd==0.15.4", - "rouge-score==0.1.2", - "sagemaker_jumpstart_script_utilities==1.1.4", - "sagemaker_jumpstart_tabular_script_utilities==1.0.0", - 
"tensorboardX==2.6", - "texttable==1.6.7", - "transformers==4.26.0", - ], - "training_vulnerabilities": [], - "deprecated": False, - "hyperparameters": [ - { - "name": "epochs", - "type": "int", - "default": 1, - "min": 1, - "max": 1000, - "scope": "algorithm", - }, - { - "name": "seed", - "type": "int", - "default": 42, - "min": 1, - "max": 1000, - "scope": "algorithm", - }, - { - "name": "batch_size", - "type": "int", - "default": 64, - "min": 1, - "max": 1024, - "scope": "algorithm", - }, - { - "name": "learning_rate", - "type": "float", - "default": 0.0001, - "min": 1e-08, - "max": 1, - "scope": "algorithm", - }, - { - "name": "validation_split_ratio", - "type": "float", - "default": 0.05, - "min": 0, - "max": 1, - "scope": "algorithm", + "fine_tuning_supported": True, + "resource_name_base": "meta-textgenerationneuron-llama-2-7b", + "default_payloads": { + "meaningOfLife": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "generated_text"}, + "body": { + "inputs": "I believe the meaning of life is", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, }, - {"name": "train_data_split_seed", "type": "int", "default": 0, "scope": "algorithm"}, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", + "theoryOfRelativity": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "generated_text"}, + "body": { + "inputs": "Simply put, the theory of relativity states that ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", + "teamMessage": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "generated_text"}, + "body": { + "inputs": "A brief message congratulating the team on the launch:\n\nHi " + "everyone,\n\nI just ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", + "englishToFrench": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "generated_text"}, + "body": { + "inputs": "Translate English to French:\nsea otter => loutre de mer\npep" + "permint => menthe poivr\u00e9e\nplush girafe => girafe peluche\ncheese =>", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, }, - ], - "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/" - "v1.1.0/sourcedir.tar.gz", - "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/" - "text2text/prepack/v1.0.1/sourcedir.tar.gz", - "training_prepacked_script_version": "1.0.1", - "training_ecr_specs": { - "framework": "huggingface", - "framework_version": "1.10.2", - "py_version": "py38", - "huggingface_transformers_version": "4.17.0", }, - "training_artifact_key": "huggingface-training/train-huggingface-text2text-flan-t5-base.tar.gz", - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": 
False, - }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "ENDPOINT_SERVER_TIMEOUT", - "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/djl-in" + "ference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-east-1": { + "alias_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/djl-in" + "ference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-northeast-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/d" + "jl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-northeast-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-south-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-southeast-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-southeast-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "ca-central-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "cn-north-1": { + "alias_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-central-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-north-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-south-1": { + "alias_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-west-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/d" + "jl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-west-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/d" + "jl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-west-3": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/d" + "jl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "me-south-1": { + "alias_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "sa-east-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "us-east-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "us-east-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "us-west-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.co" + 
"m/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "us-west-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-" + "inference:0.24.0-neuronx-sdk2.14.1" + }, }, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, - "scope": "container", - "required_for_model_class": True, + "variants": { + "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "ml.inf2.xlarge": { + "properties": { + "environment_variables": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "OPTION_N_POSITIONS": "1024", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "1", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + } + }, + "ml.inf2.8xlarge": { + "properties": { + "environment_variables": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + } + }, + "ml.inf2.24xlarge": { + "properties": { + "environment_variables": { + "OPTION_TENSOR_PARALLEL_DEGREE": "12", + "OPTION_N_POSITIONS": "4096", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + } + }, + "ml.inf2.48xlarge": { + "properties": { + "environment_variables": { + "OPTION_TENSOR_PARALLEL_DEGREE": "24", + "OPTION_N_POSITIONS": "4096", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + } + }, }, - ], - "metrics": [ - {"Name": "huggingface-text2text:eval-loss", "Regex": "'eval_loss': ([0-9\\.]+)"} - ], - "default_inference_instance_type": "ml.g5.xlarge", - "supported_inference_instance_types": [ - "ml.g5.xlarge", - "ml.p2.xlarge", - "ml.g4dn.xlarge", - "ml.p3.2xlarge", - ], - "default_training_instance_type": "ml.p3.16xlarge", - 
"supported_training_instance_types": [ - "ml.p3.8xlarge", - "ml.p3.16xlarge", - "ml.p3dn.24xlarge", - "ml.g5.24xlarge", - "ml.g5.48xlarge", - ], - "model_kwargs": {}, - "deploy_kwargs": {}, - "estimator_kwargs": {"encrypt_inter_container_traffic": False}, - "fit_kwargs": {}, - "predictor_specs": { - "supported_content_types": ["application/x-text"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-text", - "default_accept_type": "application/json", - }, - }, - "js-model-class-model-prepacked": { - "model_id": "huggingface-txt2img-conflictx-complex-lineart", - "url": "https://huggingface.co/Conflictx/Complex-Lineart", - "version": "1.1.0", - "min_sdk_version": "2.81.0", - "training_supported": False, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "huggingface", - "framework_version": "1.10.2", - "py_version": "py38", - "huggingface_transformers_version": "4.17.0", }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-txt2img-conflictx-complex-lineart.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/txt2img/v1.1.0/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-txt2img-" - "conflictx-complex-lineart.tar.gz", - "hosting_prepacked_artifact_version": "1.0.0", - "inference_vulnerable": False, - "inference_dependencies": [ - "accelerate==0.16.0", - "diffusers==0.12.1", - "huggingface_hub==0.12.0", - "transformers==4.26.0", + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorc" + "h-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch" + "-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-" + "pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingfa" + "ce-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch" + "-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/hu" + "ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/hug" + "gingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/hu" + 
"ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/hu" + "ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/hug" + "gingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/hug" + "gingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggi" + "ngface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggin" + "gface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/hugg" + "ingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/hu" + "ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + "neuron_ecr_uri": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-" + "training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-py" + "torch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + "neuron_ecr_uri": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-trai" + "ning-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytor" + "ch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface" + "-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + "neuron_ecr_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch" + "-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "trn1": { + "regional_properties": {"image_uri": "$neuron_ecr_uri"}, + "properties": { + "gated_model_key_env_var_value": "meta-training/trn1/v1.0." 
+ "0/train-meta-textgenerationneuron-llama-2-7b.tar.gz" + }, + }, + "trn1n": { + "regional_properties": {"image_uri": "$neuron_ecr_uri"}, + "properties": { + "gated_model_key_env_var_value": "meta-training/trn1n/v1.0.0" + "/train-meta-textgenerationneuron-llama-2-7b.tar.gz" + }, + }, + }, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "hosting_resource_requirements": {"min_memory_mb": 8192, "num_accelerators": 1}, + "dynamic_container_deployment_supported": True, + }, + "gated_variant-model": { + "model_id": "pytorch-ic-mobilenet-v2", + "gated_bucket": True, + "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", + "version": "1.0.0", + "min_sdk_version": "2.49.0", + "training_supported": True, + "incremental_training_supported": True, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "training_instance_type_variants": None, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + } + }, + "variants": { + "p2": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + }, + "properties": { + "prepacked_artifact_key": "some-instance-specific/model/prefix/" + }, + }, + "p3": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + } + }, + "p4": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + } + }, + "g4dn": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + } + }, + "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.48xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} + }, + "ml.g5.12xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} + }, + }, + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": None, + "hosting_prepacked_artifact_key": None, + "training_model_package_artifact_uris": None, + "deprecate_warn_message": None, + "deprecated_message": None, + "hosting_eula_key": None, + "hyperparameters": [ + { + "name": "epochs", + "type": "int", + "default": 3, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "adam-learning-rate", + "type": "float", + "default": 0.05, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "batch-size", + "type": "int", + "default": 4, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, ], - "inference_vulnerabilities": [], 
- "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", @@ -3918,181 +4494,149 @@ "required_for_model_class": True, }, ], - "metrics": [], - "default_inference_instance_type": "ml.p3.2xlarge", - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge", "ml.g4dn.xlarge"], - "model_kwargs": {}, - "deploy_kwargs": {}, - "predictor_specs": { - "supported_content_types": ["application/json"], + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": [ + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + ], + "default_training_instance_type": "ml.p3.2xlarge", + "supported_training_instance_types": [ + "ml.p3.2xlarge", + "ml.p2.xlarge", + "ml.g4dn.2xlarge", + "ml.m5.xlarge", + "ml.c5.2xlarge", + ], + "hosting_use_script_uri": False, + "metrics": [ + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeyyyuyuyuyneration:train-loss", + "Regex": "'loss default': ([0-9]+\\.[0-9]+)", + }, + ], + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + }, + "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, + "predictor_specs": { + "supported_content_types": ["application/x-image"], "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/json", + "default_content_type": "application/x-image", "default_accept_type": "application/json", }, + "inference_volume_size": 123, + "training_volume_size": 456, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": False, + "resource_name_base": "dfsdfsds", }, - "deprecated_model": { - "model_id": "huggingface-text2text-flan-t5-base", - "url": "https://huggingface.co/google/flan-t5-base", - "version": "1.2.0", - "min_sdk_version": "2.130.0", + "model-artifact-variant-model": { + "model_id": "pytorch-ic-mobilenet-v2", + "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", + "version": "3.0.6", + "min_sdk_version": "2.189.0", "training_supported": True, - "incremental_training_supported": False, + "incremental_training_supported": True, "hosting_ecr_specs": { - "framework": "huggingface", - "framework_version": "1.10.2", + "framework": "pytorch", + "framework_version": "1.10.0", "py_version": "py38", - "huggingface_transformers_version": "4.17.0", }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-base.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.4/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-" - "huggingface-text2text-flan-t5-base.tar.gz", + "hosting_artifact_key": "pytorch-ic/pytorch-ic-mobilenet-v2/artifacts/inference/v2.0.0/", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v2.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "pytorch-ic/pytorch-ic-mobilenet-v2/artifacts/inference-prepack/v1.0.0/", "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, "inference_vulnerable": False, 
"inference_dependencies": [], "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], + "training_dependencies": ["sagemaker_jumpstart_prepack_script_utilities==1.0.0"], "training_vulnerabilities": [], - "deprecated": True, - "hyperparameters": [], - "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/" - "v1.1.0/sourcedir.tar.gz", - "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/" - "text2text/prepack/v1.0.1/sourcedir.tar.gz", - "training_prepacked_script_version": "1.0.1", - "training_ecr_specs": { - "framework": "huggingface", - "framework_version": "1.10.2", - "py_version": "py38", - "huggingface_transformers_version": "4.17.0", - }, - "training_artifact_key": "huggingface-training/train-huggingface-text2text-flan-t5-base.tar.gz", - "inference_environment_variables": [ + "deprecated": False, + "hyperparameters": [ { - "name": "SAGEMAKER_PROGRAM", + "name": "train_only_top_layer", "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + "name": "epochs", + "type": "int", + "default": 5, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "name": "learning_rate", + "type": "float", + "default": 0.001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, + "name": "batch_size", + "type": "int", + "default": 4, + "min": 1, + "max": 1024, + "scope": "algorithm", }, { - "name": "ENDPOINT_SERVER_TIMEOUT", - "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, + "name": "reinitialize_top_layer", + "type": "text", + "default": "Auto", + "options": ["Auto", "True", "False"], + "scope": "algorithm", }, { - "name": "MODEL_CACHE_ROOT", + "name": "sagemaker_submit_directory", "type": "text", - "default": "/opt/ml/model", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", - "required_for_model_class": True, }, { - "name": "SAGEMAKER_ENV", + "name": "sagemaker_program", "type": "text", - "default": "1", + "default": "transfer_learning.py", "scope": "container", - "required_for_model_class": True, }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", "scope": "container", - "required_for_model_class": True, }, ], - "metrics": [ - {"Name": "huggingface-text2text:eval-loss", "Regex": "'eval_loss': ([0-9\\.]+)"} - ], - "default_inference_instance_type": "ml.g5.xlarge", - "supported_inference_instance_types": [ - "ml.g5.xlarge", - "ml.p2.xlarge", - "ml.g4dn.xlarge", - "ml.p3.2xlarge", - ], - "default_training_instance_type": "ml.p3.16xlarge", - "supported_training_instance_types": [ - "ml.p3.8xlarge", - "ml.p3.16xlarge", - "ml.p3dn.24xlarge", - "ml.g5.24xlarge", - "ml.g5.48xlarge", - ], - "model_kwargs": {}, - "deploy_kwargs": {}, - "estimator_kwargs": {"encrypt_inter_container_traffic": False}, - "fit_kwargs": {}, - "predictor_specs": { - 
"supported_content_types": ["application/x-text"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-text", - "default_accept_type": "application/json", - }, - }, - "vulnerable_model": { - "model_id": "huggingface-text2text-flan-t5-base", - "url": "https://huggingface.co/google/flan-t5-base", - "version": "1.2.0", - "min_sdk_version": "2.130.0", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "huggingface", - "framework_version": "1.10.2", - "py_version": "py38", - "huggingface_transformers_version": "4.17.0", - }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-base.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.4/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-" - "huggingface-text2text-flan-t5-base.tar.gz", - "hosting_prepacked_artifact_version": "1.0.0", - "inference_vulnerable": True, - "inference_dependencies": ["blah"], - "inference_vulnerabilities": ["blah"], - "training_vulnerable": True, - "training_dependencies": ["blah"], - "training_vulnerabilities": ["blah"], - "deprecated": False, - "hyperparameters": [], - "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/" - "v1.1.0/sourcedir.tar.gz", - "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/" - "text2text/prepack/v1.0.1/sourcedir.tar.gz", - "training_prepacked_script_version": "1.0.1", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v2.3.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/prepack/v1.1.0/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.0", "training_ecr_specs": { - "framework": "huggingface", - "framework_version": "1.10.2", + "framework": "pytorch", + "framework_version": "1.10.0", "py_version": "py38", - "huggingface_transformers_version": "4.17.0", }, - "training_artifact_key": "huggingface-training/train-huggingface-text2text-flan-t5-base.tar.gz", + "training_artifact_key": "pytorch-training/v2.0.0/train-pytorch-ic-mobilenet-v2.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", @@ -4151,220 +4695,487 @@ "required_for_model_class": True, }, ], - "metrics": [ - {"Name": "huggingface-text2text:eval-loss", "Regex": "'eval_loss': ([0-9\\.]+)"} - ], - "default_inference_instance_type": "ml.g5.xlarge", + "metrics": [{"Name": "pytorch-ic:val-accuracy", "Regex": "val_accuracy: ([0-9\\.]+)"}], + "default_inference_instance_type": "ml.m5.large", "supported_inference_instance_types": [ - "ml.g5.xlarge", - "ml.p2.xlarge", - "ml.g4dn.xlarge", - "ml.p3.2xlarge", - ], - "default_training_instance_type": "ml.p3.16xlarge", - "supported_training_instance_types": [ - "ml.p3.8xlarge", - "ml.p3.16xlarge", - "ml.p3dn.24xlarge", - "ml.g5.24xlarge", - "ml.g5.48xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + "ml.m4.large", + "ml.m4.xlarge", ], + "default_training_instance_type": "ml.m5.xlarge", + "supported_training_instance_types": ["ml.m5.xlarge", "ml.c5.2xlarge", "ml.m4.xlarge"], "model_kwargs": {}, "deploy_kwargs": {}, - "estimator_kwargs": {"encrypt_inter_container_traffic": False}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, "fit_kwargs": {}, "predictor_specs": { - 
"supported_content_types": ["application/x-text"], + "supported_content_types": ["application/x-image"], "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-text", + "default_content_type": "application/x-image", "default_accept_type": "application/json", }, - }, - "js-gated-artifact-non-model-package-trainable-model": { - "model_id": "meta-textgeneration-llama-2-7b", - "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", - "version": "3.0.0", - "min_sdk_version": "2.189.0", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "huggingface-llm", - "framework_version": "1.1.0", - "py_version": "py39", - }, - "training_artifact_key": "some/dummy/key", - "hosting_artifact_key": "meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/inference/v1.0.0/", - "hosting_script_key": "source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "meta-textgeneration/meta-textgen" - "eration-llama-2-7b/artifacts/inference-prepack/v1.0.0/", - "hosting_prepacked_artifact_version": "1.0.0", - "hosting_use_script_uri": False, - "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", - "inference_vulnerable": False, - "inference_dependencies": [ - "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", - "sagemaker_jumpstart_script_utilities==1.1.8", - ], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [ - "accelerate==0.21.0", - "bitsandbytes==0.39.1", - "black==23.7.0", - "brotli==1.0.9", - "datasets==2.14.1", - "fire==0.5.0", - "inflate64==0.3.1", - "loralib==0.1.1", - "multivolumefile==0.2.3", - "mypy-extensions==1.0.0", - "pathspec==0.11.1", - "peft==0.4.0", - "py7zr==0.20.5", - "pybcj==1.0.1", - "pycryptodomex==3.18.0", - "pyppmd==1.0.0", - "pytorch-triton==2.1.0+e6216047b8", - "pyzstd==0.15.9", - "safetensors==0.3.1", - "sagemaker_jumpstart_huggingface_script_utilities==1.1.3", - "sagemaker_jumpstart_script_utilities==1.1.9", - "scipy==1.11.1", - "termcolor==2.3.0", - "texttable==1.6.7", - "tokenize-rt==5.1.0", - "tokenizers==0.13.3", - "torch==2.1.0.dev20230905+cu118", - "transformers==4.31.0", - ], - "training_vulnerabilities": [], - "deprecated": False, - "hyperparameters": [ - { - "name": "int8_quantization", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "enable_fsdp", - "type": "text", - "default": "True", - "options": ["True", "False"], - "scope": "algorithm", + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/tf_flowers/", + "validation_supported": False, + "fine_tuning_supported": True, + "resource_name_base": "pt-ic-mobilenet-v2", + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-northeast-1": { + "alias_ecr_uri_3": 
"763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-south-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:1.10.0-gpu-py38", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-central-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-central-2": { + "cpu_ecr_uri_1": 
"380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-west-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "sa-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-east-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": 
"442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-west-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, }, - { - "name": "epoch", - "type": "int", - "default": 5, - "min": 1, - "max": 1000, - "scope": "algorithm", + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_2"}, + "properties": {"prepacked_artifact_key": "hello-world-1"}, + }, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, }, - { - "name": "learning_rate", - "type": "float", - "default": 0.0001, - "min": 1e-08, - "max": 1, - "scope": "algorithm", + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": 
"626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-northeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-south-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-training:1.10.0-gpu-py38", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.10.0-gpu-py38", + }, + 
"eu-central-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-west-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "sa-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-east-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": 
"763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-west-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, }, - {"name": "lora_r", "type": "int", "default": 8, "min": 1, "scope": "algorithm"}, - {"name": "lora_alpha", "type": "int", "default": 32, "min": 1, "scope": "algorithm"}, - { - "name": "lora_dropout", - "type": "float", - "default": 0.05, - "min": 0, - "max": 1, - "scope": "algorithm", + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": { + "regional_properties": {"image_uri": "$cpu_ecr_uri_1"}, + "properties": {"training_artifact_key": "hello-world-1"}, + }, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": 
{"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, }, - { - "name": "instruction_tuned", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, + }, + "private-model": { + "model_id": "pytorch-ic-mobilenet-v2", + "gated_bucket": True, + "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", + "version": "1.0.0", + "min_sdk_version": "2.49.0", + "training_supported": True, + "incremental_training_supported": True, + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" + "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + }, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + "inf_model_package_arn": "us-west-2/blah/blah/blah/inf", + "gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu", + } }, - { - "name": "chat_dataset", - "type": "text", - "default": "False", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "add_input_output_demarcation_key", - "type": "text", - "default": "True", - "options": ["True", "False"], - "scope": "algorithm", - }, - { - "name": "per_device_train_batch_size", - "type": "int", - "default": 4, - "min": 1, - "max": 1000, - "scope": "algorithm", - }, - { - "name": "per_device_eval_batch_size", - "type": "int", - "default": 1, - "min": 1, - "max": 1000, - "scope": "algorithm", - }, - { - "name": "max_train_samples", - "type": "int", - "default": -1, - "min": -1, - "scope": "algorithm", - }, - { - "name": "max_val_samples", - "type": "int", - "default": -1, - "min": -1, - "scope": "algorithm", + "variants": { + "p2": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "p3": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "p4": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "g4dn": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.48xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} + }, + "ml.g5.12xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} + }, + "inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, + "inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, }, + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "training_instance_type_variants": None, + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + 
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": None, + "hosting_prepacked_artifact_key": None, + "training_model_package_artifact_uris": None, + "deprecate_warn_message": None, + "deprecated_message": None, + "hosting_eula_key": None, + "hyperparameters": [ { - "name": "seed", + "name": "epochs", "type": "int", - "default": 10, + "default": 3, "min": 1, "max": 1000, "scope": "algorithm", }, { - "name": "max_input_length", - "type": "int", - "default": -1, - "min": -1, - "scope": "algorithm", - }, - { - "name": "validation_split_ratio", + "name": "adam-learning-rate", "type": "float", - "default": 0.2, - "min": 0, + "default": 0.05, + "min": 1e-08, "max": 1, "scope": "algorithm", }, { - "name": "train_data_split_seed", + "name": "batch-size", "type": "int", - "default": 0, - "min": 0, - "scope": "algorithm", - }, - { - "name": "preprocessing_num_workers", - "type": "text", - "default": "None", + "default": 4, + "min": 1, + "max": 1024, "scope": "algorithm", }, { @@ -4386,17 +5197,6 @@ "scope": "container", }, ], - "training_script_key": "source-directory-tarballs/" - "meta/transfer_learning/textgeneration/v1.0.4/sourcedir.tar.gz", - "training_prepacked_script_key": "source-directory-" - "tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz", - "training_prepacked_script_version": "1.0.1", - "training_ecr_specs": { - "framework": "huggingface", - "framework_version": "2.0.0", - "py_version": "py310", - "huggingface_transformers_version": "4.28.1", - }, "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", @@ -4447,34 +5247,6 @@ "scope": "container", "required_for_model_class": True, }, - { - "name": "HF_MODEL_ID", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "MAX_INPUT_LENGTH", - "type": "text", - "default": "4095", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "MAX_TOTAL_TOKENS", - "type": "text", - "default": "4096", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "SM_NUM_GPUS", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, - }, { "name": "SAGEMAKER_MODEL_SERVER_WORKERS", "type": "int", @@ -4483,356 +5255,249 @@ "required_for_model_class": True, }, ], - "metrics": [ - { - "Name": "huggingface-textgeneration:eval-loss", - "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", - }, - { - "Name": "huggingface-textgeneration:eval-ppl", - "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "train_epoch_loss=([0-9\\.]+)", - }, - ], - "default_inference_instance_type": "ml.g5.2xlarge", + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "default_inference_instance_type": "ml.p2.xlarge", "supported_inference_instance_types": [ - "ml.g5.2xlarge", - "ml.g5.4xlarge", - "ml.g5.8xlarge", - "ml.g5.12xlarge", - "ml.g5.24xlarge", - "ml.g5.48xlarge", - "ml.p4d.24xlarge", + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + 
"ml.c5.2xlarge", ], - "default_training_instance_type": "ml.g5.12xlarge", + "default_training_instance_type": "ml.p3.2xlarge", "supported_training_instance_types": [ - "ml.g5.12xlarge", - "ml.g5.24xlarge", - "ml.g5.48xlarge", - "ml.p3dn.24xlarge", + "ml.p3.2xlarge", + "ml.p2.xlarge", + "ml.g4dn.2xlarge", + "ml.m5.xlarge", + "ml.c5.2xlarge", ], - "model_kwargs": {}, - "deploy_kwargs": { - "model_data_download_timeout": 1200, - "container_startup_health_check_timeout": 1200, + "hosting_use_script_uri": True, + "metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], + "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, + "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, }, - "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, - "fit_kwargs": {}, + "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", + "supported_content_types": ["application/x-image"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/x-image", "default_accept_type": "application/json", }, - "inference_volume_size": 256, - "training_volume_size": 256, + "inference_volume_size": 123, + "training_volume_size": 456, "inference_enable_network_isolation": True, - "training_enable_network_isolation": True, - "default_training_dataset_key": "training-datasets/sec_amazon/", - "validation_supported": True, - "fine_tuning_supported": True, - "resource_name_base": "meta-textgeneration-llama-2-7b", - "default_payloads": { - "meaningOfLife": { - "content_type": "application/json", - "prompt_key": "inputs", - "output_keys": {"generated_text": "[0].generated_text"}, - "body": { - "inputs": "I believe the meaning of life is", - "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, - }, - }, - "theoryOfRelativity": { - "content_type": "application/json", - "prompt_key": "inputs", - "output_keys": {"generated_text": "[0].generated_text"}, - "body": { - "inputs": "Simply put, the theory of relativity states that ", - "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, - }, - }, - "teamMessage": { - "content_type": "application/json", - "prompt_key": "inputs", - "output_keys": {"generated_text": "[0].generated_text"}, - "body": { - "inputs": "A brief message congratulating the team on the launch:\n\nHi everyone,\n\nI just ", - "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, - }, - }, - "englishToFrench": { - "content_type": "application/json", - "prompt_key": "inputs", - "output_keys": {"generated_text": "[0].generated_text"}, - "body": { - "inputs": "Translate English to French:\nsea o" - "tter => loutre de mer\npeppermint => ment" - "he poivr\u00e9e\nplush girafe => girafe peluche\ncheese =>", - "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, - }, - }, - "Story": { - "content_type": "application/json", - "prompt_key": "inputs", - "output_keys": { - "generated_text": "[0].generated_text", - "input_logprobs": "[0].details.prefill[*].logprob", - }, - "body": { - "inputs": "Please tell me a story.", - "parameters": { - "max_new_tokens": 64, - "top_p": 0.9, - "temperature": 0.2, - "decoder_input_details": True, - "details": True, - }, - }, - 
}, - }, - "gated_bucket": True, - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/h" - "uggingface-pytorch-tgi-inference:2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04" - }, - }, - "variants": { - "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, - "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, - "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, - "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, - }, - }, - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazon" - "aws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - }, - }, - "variants": { - "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "g5": { - "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, - "properties": { - "gated_model_key_env_var_value": "meta-training/train-meta-textgeneration-llama-2-7b.tar.gz", - "environment_variables": {"SELF_DESTRUCT": "true"}, - }, - }, - "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - }, - }, - "dynamic_container_deployment_supported": False, + "training_enable_network_isolation": False, + "resource_name_base": "dfsdfsds", }, - "js-gated-artifact-trainable-model": { + "js-model-package-arn": { "model_id": "meta-textgeneration-llama-2-7b-f", "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", - "version": "2.0.0", - "min_sdk_version": "2.173.0", + "version": "2.0.4", + "min_sdk_version": "2.174.0", "training_supported": True, "incremental_training_supported": False, "hosting_ecr_specs": { "framework": "djl-deepspeed", - "framework_version": "0.21.0", + "framework_version": "0.23.0", "py_version": "py39", }, "hosting_artifact_key": "meta-infer/infer-meta-textgeneration-llama-2-7b-f.tar.gz", + "hosting_script_key": "source-directory-tarballs/meta/inference/textgeneration/v1.2.2/sourcedir.tar.gz", "hosting_use_script_uri": False, - "hosting_script_key": "source-directory-tarballs/meta/inference/textgeneration/v1.0.0/sourcedir.tar.gz", "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", "hosting_model_package_arns": { - "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/" - "llama2-7b-f-e46eb8a833643ed58aaccd81498972c3", - "us-east-1": 
"arn:aws:sagemaker:us-east-1:865070037744:model-package/" - "llama2-7b-f-e46eb8a833643ed58aaccd81498972c3", - "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/" - "llama2-7b-f-e46eb8a833643ed58aaccd81498972c3", - "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/" - "llama2-7b-f-e46eb8a833643ed58aaccd81498972c3", + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", }, "training_model_package_artifact_uris": { - "us-west-2": "s3://jumpstart-cache-alpha-us-west-2/dummy.tar.gz", - "us-east-1": "s3://jumpstart-cache-alpha-us-west-2/dummy.tar.gz", - "eu-west-1": "s3://jumpstart-cache-alpha-us-west-2/dummy.tar.gz", - "ap-southeast-1": "s3://jumpstart-cache-alpha-us-west-2/dummy.tar.gz", + "us-west-2": "s3://sagemaker-repository-pdx/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "us-east-2": "s3://sagemaker-repository-cmh/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "us-east-1": "s3://sagemaker-repository-iad/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "eu-west-1": "s3://sagemaker-repository-dub/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "ap-southeast-1": "s3://sagemaker-repository-sin/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "ap-southeast-2": "s3://sagemaker-repository-syd/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", }, "inference_vulnerable": False, - "inference_dependencies": [], + "inference_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", + "sagemaker_jumpstart_script_utilities==1.1.8", + ], "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], + "training_vulnerable": True, + "training_dependencies": [ + "accelerate==0.21.0", + "bitsandbytes==0.39.1", + "black==23.7.0", + "brotli==1.0.9", + "datasets==2.14.1", + "fire==0.5.0", + "inflate64==0.3.1", + "loralib==0.1.1", + "multivolumefile==0.2.3", + "mypy-extensions==1.0.0", + "pathspec==0.11.1", + "peft==0.4.0", + "py7zr==0.20.5", + "pybcj==1.0.1", + "pycryptodomex==3.18.0", + "pyppmd==1.0.0", + "pytorch-triton==2.1.0+6e4932cda8", + "pyzstd==0.15.9", + "safetensors==0.3.1", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.3", + "sagemaker_jumpstart_script_utilities==1.1.9", + "scipy==1.11.1", + "termcolor==2.3.0", + "texttable==1.6.7", + "tokenize-rt==5.1.0", + "tokenizers==0.13.3", + "torch==2.2.0.dev20231104+cu118", + "transformers==4.31.0", + ], + "training_vulnerabilities": ["transformers==4.31.0"], "deprecated": False, + "deprecate_warn_message": "For forward compatibility, pin to model_version='2.*' in your JumpStartModel or JumpStartEstimator definitions. 
Note that major version upgrades may have different EULA acceptance terms and input/output signatures.", "hyperparameters": [ { - "name": "sagemaker_submit_directory", + "name": "int8_quantization", "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "sagemaker_program", + "name": "enable_fsdp", "type": "text", - "default": "transfer_learning.py", - "scope": "container", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", + "name": "epoch", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", }, - ], - "training_script_key": "source-directory-tarballs/meta/transfer_learning/" - "textgeneration/v1.0.0/sourcedir.tar.gz", - "training_ecr_specs": { - "framework": "djl-deepspeed", - "framework_version": "0.21.0", - "py_version": "py39", - }, - "training_artifact_key": "meta-training/train-meta-textgeneration-llama-2-7b-f.tar.gz", - "inference_environment_variables": [], - "metrics": [], - "default_inference_instance_type": "ml.g5.2xlarge", - "supported_inference_instance_types": [ - "ml.g5.2xlarge", - "ml.g5.4xlarge", - "ml.g5.8xlarge", - "ml.g5.12xlarge", - "ml.g5.24xlarge", - "ml.g5.48xlarge", - "ml.p4d.24xlarge", - ], - "default_training_instance_type": "ml.p3.2xlarge", - "supported_training_instance_types": ["ml.p3.2xlarge", "ml.p2.8xlarge", "ml.g4dn.xlarge"], - "model_kwargs": {}, - "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 3600, - }, - "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, - "fit_kwargs": {}, - "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", - "default_accept_type": "application/json", - }, - "inference_volume_size": 256, - "inference_enable_network_isolation": True, - "training_enable_network_isolation": True, - "default_training_dataset_key": "training-datasets/wikitext/", - "validation_supported": False, - "fine_tuning_supported": True, - "resource_name_base": "meta-textgeneration-llama-2-7b-f", - }, - "js-trainable-model": { - "model_id": "autogluon-classification-ensemble", - "url": "https://auto.gluon.ai/stable/index.html", - "version": "1.1.1", - "min_sdk_version": "2.103.0", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "autogluon", - "framework_version": "0.4.3", - "py_version": "py38", - }, - "hosting_artifact_key": "autogluon-infer/v1.1.0/infer-autogluon-classification-ensemble.tar.gz", - "hosting_script_key": "source-directory-tarballs/autogluon/inference/classification/v1.0.0/sourcedir.tar.gz", - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": ["sagemaker_jumpstart_script_utilities==1.0.1"], - "training_vulnerabilities": [], - "deprecated": False, - "hyperparameters": [ - {"name": "eval_metric", "type": "text", "default": "auto", "scope": "algorithm"}, { - "name": "presets", - "type": "text", - "default": "medium_quality", - "options": [ - "best_quality", - "high_quality", - "good_quality", - "medium_quality", - "optimize_for_deployment", - "interpretable", - ], + "name": 
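The deprecate_warn_message that ends just above tells callers to pin the major version; in user code that looks like the following sketch (model_version wildcards such as "2.*" are a JumpStart convention):

from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(
    model_id="meta-textgeneration-llama-2-7b-f",
    model_version="2.*",  # pin the major version, per the warning above
)
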
"learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + {"name": "lora_r", "type": "int", "default": 8, "min": 1, "scope": "algorithm"}, + {"name": "lora_alpha", "type": "int", "default": 32, "min": 1, "scope": "algorithm"}, + { + "name": "lora_dropout", + "type": "float", + "default": 0.05, + "min": 0, + "max": 1, "scope": "algorithm", }, { - "name": "auto_stack", + "name": "instruction_tuned", "type": "text", "default": "False", "options": ["True", "False"], "scope": "algorithm", }, { - "name": "num_bag_folds", + "name": "chat_dataset", "type": "text", - "default": "0", - "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], - "scope": "algorithm", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, - {"name": "num_bag_sets", "type": "int", "default": 1, "min": 1, "scope": "algorithm"}, { - "name": "num_stack_levels", + "name": "add_input_output_demarcation_key", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", "type": "int", - "default": 0, - "min": 0, - "max": 3, + "default": 1, + "min": 1, + "max": 1000, "scope": "algorithm", }, { - "name": "refit_full", - "type": "text", - "default": "False", - "options": ["True", "False"], + "name": "per_device_eval_batch_size", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, "scope": "algorithm", }, { - "name": "set_best_to_refit_full", - "type": "text", - "default": "False", - "options": ["True", "False"], + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, "scope": "algorithm", }, { - "name": "save_space", - "type": "text", - "default": "False", - "options": ["True", "False"], + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, "scope": "algorithm", }, { - "name": "verbosity", + "name": "seed", "type": "int", - "default": 2, + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, "min": 0, - "max": 4, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", "scope": "algorithm", }, { @@ -4854,610 +5519,5605 @@ "scope": "container", }, ], - "training_script_key": "source-directory-tarballs/autogluon/transfer_learning/classification/" - "v1.0.2/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/meta/transfer_learning/textgeneration/v1.0.6/sourcedir.tar.gz", "training_ecr_specs": { - "framework": "autogluon", - "framework_version": "0.4.3", - "py_version": "py38", + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", }, - "training_artifact_key": "autogluon-training/train-autogluon-classification-ensemble.tar.gz", - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - 
"default": "20", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, - }, - { - "name": "ENDPOINT_SERVER_TIMEOUT", - "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, - }, + "training_artifact_key": "meta-training/train-meta-textgeneration-llama-2-7b-f.tar.gz", + "inference_environment_variables": [], + "metrics": [ { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", }, { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, - "scope": "container", - "required_for_model_class": True, + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", }, ], - "metrics": [], - "default_inference_instance_type": "ml.p2.xlarge", + "default_inference_instance_type": "ml.g5.2xlarge", "supported_inference_instance_types": [ - "ml.p2.xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", - "ml.m5.2xlarge", - "ml.m5.4xlarge", - "ml.m5.12xlarge", - "ml.m5.24xlarge", - "ml.c5.2xlarge", - "ml.c5.4xlarge", - "ml.c5.9xlarge", - "ml.c5.18xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", ], - "default_training_instance_type": "ml.p3.2xlarge", + "default_training_instance_type": "ml.g5.12xlarge", "supported_training_instance_types": [ - "ml.m5.xlarge", - "ml.m5.2xlarge", - "ml.m5.4xlarge", - "ml.m5.12xlarge", - "ml.m5.24xlarge", - "ml.c5.2xlarge", - "ml.c5.4xlarge", - "ml.c5.9xlarge", - "ml.c5.18xlarge", - "ml.p2.xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", ], "model_kwargs": {}, - "deploy_kwargs": {}, - "estimator_kwargs": {"encrypt_inter_container_traffic": True}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, + }, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, "fit_kwargs": {}, "predictor_specs": { - "supported_content_types": ["text/csv"], + "supported_content_types": ["application/json"], "supported_accept_types": ["application/json"], - "default_content_type": "text/csv", + "default_content_type": "application/json", "default_accept_type": "application/json", }, - "resource_name_base": "blahblahblah", + "inference_volume_size": 256, + "training_volume_size": 256, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/oasst_top/train/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "meta-textgeneration-llama-2-7b-f", + "default_payloads": { + "Mayo": { + "content_type": "application/json", + "body": { + "inputs": [[{"role": "user", "content": "what is the recipe of mayonnaise?"}]], + "parameters": {"max_new_tokens": 512, "top_p": 0.9, "temperature": 0.6}, + }, + } + }, + "gated_bucket": True, "hosting_instance_type_variants": { "regional_aliases": { 
"af-south-1": { - "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1" - ".amazonaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1" - ".amazonaws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "ap-east-1": { - "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1." - "amazonaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1." - "amazonaws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "ap-northeast-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-" - "1.amazonaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-" - "1.amazonaws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "ap-northeast-2": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2" - ".amazonaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2" - ".amazonaws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "ap-south-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.ama" - "zonaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazo" - "naws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "ap-southeast-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazo" - "naws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.a" - "mazonaws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "ap-southeast-2": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.am" - "azonaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazon" - "aws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "ca-central-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amaz" - "onaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazon" - "aws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "cn-north-1": { - "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaw" - "s.com.cn/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws" - ".com.cn/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "eu-central-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazona" - "ws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaw" - "s.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, 
"eu-north-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazona" - "ws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazona" - "ws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "eu-south-1": { - "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amaz" - "onaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazo" - "naws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "eu-west-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1." - "amazonaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.ama" - "zonaws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "eu-west-2": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazona" - "ws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaw" - "s.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "eu-west-3": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws" - ".com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amaz" - "onaws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "me-south-1": { - "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amaz" - "onaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazo" - "naws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "sa-east-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amaz" - "onaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.ama" - "zonaws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "us-east-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.ama" - "zonaws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.am" - "azonaws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "us-east-2": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazo" - "naws.com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazon" - "aws.com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "us-west-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws." - "com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws." 
- "com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, "us-west-2": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws" - ".com/autogluon-inference:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws." - "com/autogluon-inference:0.4.3-gpu-py38", + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" }, }, "variants": { - "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t2": {"regional_properties": 
{"image_uri": "$alias_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, }, }, "training_instance_type_variants": { "regional_aliases": { "af-south-1": { - "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south" - "-1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amaz" - "naws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "ap-east-1": { - "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-" - "1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amaz" - "onaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "ap-northeast-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.ama" - "zonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.a" - "mazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "ap-northeast-2": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2." - "amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2." - "amazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "ap-south-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-sou" - "th-1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.ama" - "zonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "ap-southeast-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-" - "1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast" - "-1.amazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "ap-southeast-2": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast" - "-2.amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2." 
- "amazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "ca-central-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.a" - "mazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.am" - "azonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "eu-central-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.a" - "mazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.am" - "azonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "eu-north-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.am" - "azonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazon" - "aws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "eu-south-1": { - "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amaz" - "onaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.ama" - "zonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "eu-west-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amaz" - "onaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazo" - "naws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "eu-west-2": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.am" - "azonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.am" - "azonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "eu-west-3": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3." - "amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3." - "amazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "me-south-1": { - "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-sout" - "h-1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1." - "amazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "sa-east-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1." 
- "amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1" - ".amazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "us-east-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1." - "amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1." - "amazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "us-east-2": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2" - ".amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-" - "2.amazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "us-west-1": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west" - "-1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-wes" - "t-1.amazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, "us-west-2": { - "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west" - "-2.amazonaws.com/autogluon-training:0.4.3-cpu-py38", - "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-w" - "est-2.amazonaws.com/autogluon-training:0.4.3-gpu-py38", + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" }, }, - "variants": { - "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, - "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, - "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": 
{"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "dynamic_container_deployment_supported": False, + }, + "js-trainable-model-prepacked": { + "model_id": "huggingface-text2text-flan-t5-base", + "url": "https://huggingface.co/google/flan-t5-base", + "version": "2.2.3", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.4.0", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", + }, + "hosting_artifact_key": "huggingface-text2text/huggingface-text2text-flan-t5-base/artifacts/inference/v2.0.0/", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v2.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-text2text/huggingface-text2text-flan-t5-base/artifacts/inference-prepack/v2.0.0/", + "hosting_prepacked_artifact_version": "2.0.0", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.23.0", + "datasets==2.12.0", + "deepspeed==0.10.3", + "peft==0.5.0", + "safetensors==0.3.3", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.2", + "sagemaker_jumpstart_script_utilities==1.1.8", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "epochs", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, + { + "name": "seed", + "type": "int", + "default": 42, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "batch_size", + "type": "int", + "default": 64, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.05, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "train_data_split_seed", "type": "int", "default": 0, "scope": "algorithm"}, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_eval_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_output_length", + "type": "int", + "default": 128, + "min": 0, + "scope": "algorithm", + }, + { + "name": "pad_to_max_length", + "type": "text", + "default": "True", + "options": ["True", "False"], + 
"scope": "algorithm", + }, + { + "name": "gradient_accumulation_steps", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_strategy", + "type": "text", + "default": "steps", + "options": ["no", "steps", "epoch"], + "scope": "algorithm", + }, + { + "name": "logging_first_step", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 500, + "min": 1, + "scope": "algorithm", + }, + { + "name": "logging_nan_inf_filter", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "epoch", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "text", "default": "2", "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "evaluation_strategy", + "type": "text", + "default": "epoch", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "eval_steps", "type": "text", "default": "500", "scope": "algorithm"}, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "gradient_checkpointing", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "peft_type", + "type": "text", + "default": "none", + "options": ["lora", "none"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": 
"container", + }, + ], + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/v2.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/prepack/v2.0.0/sourcedir.tar.gz", + "training_prepacked_script_version": "2.0.0", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "huggingface-training/train-huggingface-text2text-flan-t5-base.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + {"Name": "huggingface-text2text:eval-loss", "Regex": "'eval_loss': ([0-9\\.]+)"} + ], + "default_inference_instance_type": "ml.g5.2xlarge", + "supported_inference_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + ], + "default_training_instance_type": "ml.p3.16xlarge", + "supported_training_instance_types": [ + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.p3dn.24xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, + }, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "training_volume_size": 512, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": 
"training-datasets/genuq/dev/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-text2text-flan-t5-base", + "default_payloads": { + "Summarization": { + "content_type": "application/json", + "prompt_key": "inputs", + "body": { + "inputs": "Summarize this content - Amazon Comprehend uses natural language processing (NLP) to extract insights about the content of documents. It develops insights by recognizing the entities, key phrases, language, sentiments, and other common elements in a document. Use Amazon Comprehend to create new products based on understanding the structure of documents. For example, using Amazon Comprehend you can search social networking feeds for mentions of products or scan an entire document repository for key phrases. You can access Amazon Comprehend document analysis capabilities using the Amazon Comprehend console or using the Amazon Co", + "parameters": { + "max_new_tokens": 400, + "decoder_input_details": True, + "details": True, + }, + }, + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + } + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-3": { + "gpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-3": { + "gpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "cn-northwest-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-central-2": { + "gpu_ecr_uri_1": 
"380420809688.dkr.ecr.eu-central-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "il-central-1": { + "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "me-central-1": { + "gpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-gov-east-1": { + "gpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-gov-west-1": { + "gpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": 
{"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-3": { + "gpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-3": { + "gpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-northwest-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "il-central-1": { + "gpu_ecr_uri_1": 
"780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "me-central-1": { + "gpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-gov-east-1": { + "gpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-gov-west-1": { + "gpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, + "bedrock_console_supported": True, + "bedrock_io_mapping_id": "tgi_default_1.0.0", + }, + "deprecated_model": { + "model_id": "huggingface-text2text-flan-t5-base", + "url": "https://huggingface.co/google/flan-t5-base", + "version": "2.2.3", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.4.0", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", + }, + "hosting_artifact_key": "huggingface-text2text/huggingface-text2text-flan-t5-base/artifacts/inference/v2.0.0/", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v2.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-text2text/huggingface-text2text-flan-t5-base/artifacts/inference-prepack/v2.0.0/", + "hosting_prepacked_artifact_version": "2.0.0", + 
"hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.23.0", + "datasets==2.12.0", + "deepspeed==0.10.3", + "peft==0.5.0", + "safetensors==0.3.3", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.2", + "sagemaker_jumpstart_script_utilities==1.1.8", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + ], + "training_vulnerabilities": [], + "deprecated": True, + "hyperparameters": [ + { + "name": "epochs", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, + { + "name": "seed", + "type": "int", + "default": 42, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "batch_size", + "type": "int", + "default": 64, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.05, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "train_data_split_seed", "type": "int", "default": 0, "scope": "algorithm"}, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_eval_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_output_length", + "type": "int", + "default": 128, + "min": 0, + "scope": "algorithm", + }, + { + "name": "pad_to_max_length", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "gradient_accumulation_steps", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_strategy", + "type": "text", + "default": 
"steps", + "options": ["no", "steps", "epoch"], + "scope": "algorithm", + }, + { + "name": "logging_first_step", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 500, + "min": 1, + "scope": "algorithm", + }, + { + "name": "logging_nan_inf_filter", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "epoch", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "text", "default": "2", "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "evaluation_strategy", + "type": "text", + "default": "epoch", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "eval_steps", "type": "text", "default": "500", "scope": "algorithm"}, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "gradient_checkpointing", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "peft_type", + "type": "text", + "default": "none", + "options": ["lora", "none"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/v2.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/prepack/v2.0.0/sourcedir.tar.gz", + "training_prepacked_script_version": "2.0.0", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "huggingface-training/train-huggingface-text2text-flan-t5-base.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 
3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + {"Name": "huggingface-text2text:eval-loss", "Regex": "'eval_loss': ([0-9\\.]+)"} + ], + "default_inference_instance_type": "ml.g5.2xlarge", + "supported_inference_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + ], + "default_training_instance_type": "ml.p3.16xlarge", + "supported_training_instance_types": [ + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.p3dn.24xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, + }, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "training_volume_size": 512, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/genuq/dev/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-text2text-flan-t5-base", + "default_payloads": { + "Summarization": { + "content_type": "application/json", + "prompt_key": "inputs", + "body": { + "inputs": "Summarize this content - Amazon Comprehend uses natural language processing (NLP) to extract insights about the content of documents. It develops insights by recognizing the entities, key phrases, language, sentiments, and other common elements in a document. Use Amazon Comprehend to create new products based on understanding the structure of documents. For example, using Amazon Comprehend you can search social networking feeds for mentions of products or scan an entire document repository for key phrases. 
You can access Amazon Comprehend document analysis capabilities using the Amazon Comprehend console or using the Amazon Co", + "parameters": { + "max_new_tokens": 400, + "decoder_input_details": True, + "details": True, + }, + }, + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + } + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-3": { + "gpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-3": { + "gpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "cn-northwest-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-central-2": { + "gpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "il-central-1": { + "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "me-central-1": { + "gpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-gov-east-1": { + "gpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-gov-west-1": { + "gpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-3": { + "gpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-3": { + "gpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-northwest-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "il-central-1": { + "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "me-central-1": { + "gpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-gov-east-1": { + "gpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-gov-west-1": { + "gpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, + "bedrock_console_supported": True, + "bedrock_io_mapping_id": "tgi_default_1.0.0", + }, + "vulnerable_model": { + "model_id": "huggingface-text2text-flan-t5-base", + "url": "https://huggingface.co/google/flan-t5-base", + "version": "2.2.3", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.4.0", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", + }, + "hosting_artifact_key": "huggingface-text2text/huggingface-text2text-flan-t5-base/artifacts/inference/v2.0.0/", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v2.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-text2text/huggingface-text2text-flan-t5-base/artifacts/inference-prepack/v2.0.0/", + "hosting_prepacked_artifact_version": "2.0.0", + "hosting_use_script_uri": False, + "inference_dependencies": [], + "training_vulnerable": True, + "training_dependencies": [ + "accelerate==0.23.0", + "datasets==2.12.0", + "deepspeed==0.10.3", + "peft==0.5.0", + "safetensors==0.3.3", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.2", + "sagemaker_jumpstart_script_utilities==1.1.8", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + ], + "inference_vulnerable": True, + "training_vulnerabilities": ["accelerate==0.23.0"], + "training_vulnerabilities": ["accelerate==0.23.0"], + "deprecated": False, + "hyperparameters": [ + { + "name": "epochs", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, + { + "name": "seed", + "type": "int", + "default": 42, + "min": 1, + "max": 1000, + "scope": 
"algorithm", + }, + { + "name": "batch_size", + "type": "int", + "default": 64, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.05, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "train_data_split_seed", "type": "int", "default": 0, "scope": "algorithm"}, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_eval_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_output_length", + "type": "int", + "default": 128, + "min": 0, + "scope": "algorithm", + }, + { + "name": "pad_to_max_length", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "gradient_accumulation_steps", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_strategy", + "type": "text", + "default": "steps", + "options": ["no", "steps", "epoch"], + "scope": "algorithm", + }, + { + "name": "logging_first_step", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 500, + "min": 1, + "scope": "algorithm", + }, + { + "name": "logging_nan_inf_filter", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "epoch", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "text", "default": "2", "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + 
"default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "evaluation_strategy", + "type": "text", + "default": "epoch", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "eval_steps", "type": "text", "default": "500", "scope": "algorithm"}, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "gradient_checkpointing", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "peft_type", + "type": "text", + "default": "none", + "options": ["lora", "none"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/v2.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/prepack/v2.0.0/sourcedir.tar.gz", + "training_prepacked_script_version": "2.0.0", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "huggingface-training/train-huggingface-text2text-flan-t5-base.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", 
+ "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + {"Name": "huggingface-text2text:eval-loss", "Regex": "'eval_loss': ([0-9\\.]+)"} + ], + "default_inference_instance_type": "ml.g5.2xlarge", + "supported_inference_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + ], + "default_training_instance_type": "ml.p3.16xlarge", + "supported_training_instance_types": [ + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.p3dn.24xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, + }, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "training_volume_size": 512, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/genuq/dev/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-text2text-flan-t5-base", + "default_payloads": { + "Summarization": { + "content_type": "application/json", + "prompt_key": "inputs", + "body": { + "inputs": "Summarize this content - Amazon Comprehend uses natural language processing (NLP) to extract insights about the content of documents. It develops insights by recognizing the entities, key phrases, language, sentiments, and other common elements in a document. Use Amazon Comprehend to create new products based on understanding the structure of documents. For example, using Amazon Comprehend you can search social networking feeds for mentions of products or scan an entire document repository for key phrases. 
You can access Amazon Comprehend document analysis capabilities using the Amazon Comprehend console or using the Amazon Co", + "parameters": { + "max_new_tokens": 400, + "decoder_input_details": True, + "details": True, + }, + }, + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + } + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-northeast-3": { + "gpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ap-southeast-3": { + "gpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "cn-northwest-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-central-2": { + "gpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "il-central-1": { + "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "me-central-1": { + "gpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-gov-east-1": { + "gpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-gov-west-1": { + "gpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-3": { + "gpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-3": { + "gpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-northwest-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "il-central-1": { + "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "me-central-1": { + "gpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-gov-east-1": { + "gpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-gov-west-1": { + "gpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, + "bedrock_console_supported": True, + "bedrock_io_mapping_id": "tgi_default_1.0.0", + }, + "js-gated-artifact-non-model-package-trainable-model": { + "model_id": "meta-textgeneration-llama-2-7b", + "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + "version": "3.0.0", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.1.0", + "py_version": "py39", + }, + "training_artifact_key": "some/dummy/key", + "hosting_artifact_key": "meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "meta-textgeneration/meta-textgen" + "eration-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", + "inference_vulnerable": False, + "inference_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", + "sagemaker_jumpstart_script_utilities==1.1.8", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.21.0", + "bitsandbytes==0.39.1", + "black==23.7.0", + "brotli==1.0.9", + "datasets==2.14.1", + "fire==0.5.0", + "inflate64==0.3.1", + "loralib==0.1.1", + "multivolumefile==0.2.3", + "mypy-extensions==1.0.0", + "pathspec==0.11.1", + "peft==0.4.0", + "py7zr==0.20.5", + "pybcj==1.0.1", + "pycryptodomex==3.18.0", + "pyppmd==1.0.0", + "pytorch-triton==2.1.0+e6216047b8", + "pyzstd==0.15.9", + "safetensors==0.3.1", + 
"sagemaker_jumpstart_huggingface_script_utilities==1.1.3", + "sagemaker_jumpstart_script_utilities==1.1.9", + "scipy==1.11.1", + "termcolor==2.3.0", + "texttable==1.6.7", + "tokenize-rt==5.1.0", + "tokenizers==0.13.3", + "torch==2.1.0.dev20230905+cu118", + "transformers==4.31.0", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "int8_quantization", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "enable_fsdp", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "epoch", + "type": "int", + "default": 5, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + {"name": "lora_r", "type": "int", "default": 8, "min": 1, "scope": "algorithm"}, + {"name": "lora_alpha", "type": "int", "default": 32, "min": 1, "scope": "algorithm"}, + { + "name": "lora_dropout", + "type": "float", + "default": 0.05, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "instruction_tuned", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "chat_dataset", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "add_input_output_demarcation_key", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/" + "meta/transfer_learning/textgeneration/v1.0.4/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-" + "tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz", + "training_prepacked_script_version": "1.0.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "inference_environment_variables": [ + { + "name": 
"SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "4095", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "4096", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + "default_inference_instance_type": "ml.g5.2xlarge", + "supported_inference_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.12xlarge", + "supported_training_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 256, + "training_volume_size": 256, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/sec_amazon/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "meta-textgeneration-llama-2-7b", + "default_payloads": { + "meaningOfLife": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "I believe the meaning of life is", + "parameters": {"max_new_tokens": 64, 
"top_p": 0.9, "temperature": 0.6}, + }, + }, + "theoryOfRelativity": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "Simply put, the theory of relativity states that ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "teamMessage": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "A brief message congratulating the team on the launch:\n\nHi everyone,\n\nI just ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "englishToFrench": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "Translate English to French:\nsea o" + "tter => loutre de mer\npeppermint => ment" + "he poivr\u00e9e\nplush girafe => girafe peluche\ncheese =>", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "Story": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "Please tell me a story.", + "parameters": { + "max_new_tokens": 64, + "top_p": 0.9, + "temperature": 0.2, + "decoder_input_details": True, + "details": True, + }, + }, + }, + }, + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/h" + "uggingface-pytorch-tgi-inference:2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazon" + "aws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": "meta-training/train-meta-textgeneration-llama-2-7b.tar.gz", + "environment_variables": {"SELF_DESTRUCT": "true"}, + }, + }, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": 
"$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "dynamic_container_deployment_supported": False, + }, + "js-gated-artifact-trainable-model": { + "model_id": "meta-textgeneration-llama-2-7b-f", + "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + "version": "2.0.4", + "min_sdk_version": "2.174.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "djl-deepspeed", + "framework_version": "0.23.0", + "py_version": "py39", + }, + "hosting_artifact_key": "meta-infer/infer-meta-textgeneration-llama-2-7b-f.tar.gz", + "hosting_script_key": "source-directory-tarballs/meta/inference/textgeneration/v1.2.2/sourcedir.tar.gz", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + }, + "training_model_package_artifact_uris": { + "us-west-2": "s3://sagemaker-repository-pdx/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "us-east-2": "s3://sagemaker-repository-cmh/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "us-east-1": "s3://sagemaker-repository-iad/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "eu-west-1": "s3://sagemaker-repository-dub/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "ap-southeast-1": "s3://sagemaker-repository-sin/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + "ap-southeast-2": "s3://sagemaker-repository-syd/model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + }, + "inference_vulnerable": False, + "inference_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", + "sagemaker_jumpstart_script_utilities==1.1.8", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.21.0", + "bitsandbytes==0.39.1", + "black==23.7.0", + "brotli==1.0.9", + "datasets==2.14.1", + "fire==0.5.0", + "inflate64==0.3.1", + "loralib==0.1.1", + "multivolumefile==0.2.3", + "mypy-extensions==1.0.0", + "pathspec==0.11.1", + "peft==0.4.0", + "py7zr==0.20.5", + "pybcj==1.0.1", + "pycryptodomex==3.18.0", + "pyppmd==1.0.0", + "pytorch-triton==2.1.0+6e4932cda8", + "pyzstd==0.15.9", + "safetensors==0.3.1", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.3", + "sagemaker_jumpstart_script_utilities==1.1.9", + "scipy==1.11.1", + "termcolor==2.3.0", + "texttable==1.6.7", + "tokenize-rt==5.1.0", + "tokenizers==0.13.3", + "torch==2.2.0.dev20231104+cu118", + "transformers==4.31.0", + ], + 
"training_vulnerabilities": [], + "deprecated": False, + "deprecate_warn_message": "For forward compatibility, pin to model_version='2.*' in your JumpStartModel or JumpStartEstimator definitions. Note that major version upgrades may have different EULA acceptance terms and input/output signatures.", + "hyperparameters": [ + { + "name": "int8_quantization", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "enable_fsdp", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "epoch", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + {"name": "lora_r", "type": "int", "default": 8, "min": 1, "scope": "algorithm"}, + {"name": "lora_alpha", "type": "int", "default": 32, "min": 1, "scope": "algorithm"}, + { + "name": "lora_dropout", + "type": "float", + "default": 0.05, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "instruction_tuned", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "chat_dataset", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "add_input_output_demarcation_key", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/meta/transfer_learning/textgeneration/v1.0.6/sourcedir.tar.gz", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "meta-training/train-meta-textgeneration-llama-2-7b-f.tar.gz", + "inference_environment_variables": [], + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": 
"huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + "default_inference_instance_type": "ml.g5.2xlarge", + "supported_inference_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.12xlarge", + "supported_training_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, + }, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 256, + "training_volume_size": 256, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/oasst_top/train/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "meta-textgeneration-llama-2-7b-f", + "default_payloads": { + "Mayo": { + "content_type": "application/json", + "body": { + "inputs": [[{"role": "user", "content": "what is the recipe of mayonnaise?"}]], + "parameters": {"max_new_tokens": 512, "top_p": 0.9, "temperature": 0.6}, + }, + } + }, + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "ap-east-1": { + "alias_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "ap-northeast-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "ap-northeast-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "ap-south-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "ap-southeast-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "ap-southeast-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "ca-central-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "cn-north-1": { + "alias_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "eu-central-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "eu-north-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "eu-south-1": { + "alias_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "eu-west-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + 
"eu-west-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "eu-west-3": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "me-south-1": { + "alias_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "sa-east-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "us-east-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "us-east-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "us-west-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + "us-west-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118" + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + 
"dynamic_container_deployment_supported": False, + }, + "js-trainable-model": { + "model_id": "autogluon-classification-ensemble", + "url": "https://auto.gluon.ai/stable/index.html", + "version": "1.1.1", + "min_sdk_version": "2.103.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "autogluon", + "framework_version": "0.4.3", + "py_version": "py38", + }, + "hosting_artifact_key": "autogluon-infer/v1.1.0/infer-autogluon-classification-ensemble.tar.gz", + "hosting_script_key": "source-directory-tarballs/autogluon/inference/classification/v1.0.0/sourcedir.tar.gz", + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": ["sagemaker_jumpstart_script_utilities==1.0.1"], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + {"name": "eval_metric", "type": "text", "default": "auto", "scope": "algorithm"}, + { + "name": "presets", + "type": "text", + "default": "medium_quality", + "options": [ + "best_quality", + "high_quality", + "good_quality", + "medium_quality", + "optimize_for_deployment", + "interpretable", + ], + "scope": "algorithm", + }, + { + "name": "auto_stack", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "num_bag_folds", + "type": "text", + "default": "0", + "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "scope": "algorithm", + }, + {"name": "num_bag_sets", "type": "int", "default": 1, "min": 1, "scope": "algorithm"}, + { + "name": "num_stack_levels", + "type": "int", + "default": 0, + "min": 0, + "max": 3, + "scope": "algorithm", + }, + { + "name": "refit_full", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "set_best_to_refit_full", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_space", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "verbosity", + "type": "int", + "default": 2, + "min": 0, + "max": 4, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/autogluon/transfer_learning/classification/" + "v1.0.2/sourcedir.tar.gz", + "training_ecr_specs": { + "framework": "autogluon", + "framework_version": "0.4.3", + "py_version": "py38", + }, + "training_artifact_key": "autogluon-training/train-autogluon-classification-ensemble.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + 
"default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": [ + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + ], + "default_training_instance_type": "ml.p3.2xlarge", + "supported_training_instance_types": [ + "ml.m5.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["text/csv"], + "supported_accept_types": ["application/json"], + "default_content_type": "text/csv", + "default_accept_type": "application/json", + }, + "resource_name_base": "blahblahblah", + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1" + ".amazonaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1" + ".amazonaws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1." + "amazonaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1." 
+ "amazonaws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-" + "1.amazonaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-" + "1.amazonaws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2" + ".amazonaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2" + ".amazonaws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.ama" + "zonaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazo" + "naws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazo" + "naws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.a" + "mazonaws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.am" + "azonaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazon" + "aws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amaz" + "onaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazon" + "aws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaw" + "s.com.cn/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws" + ".com.cn/autogluon-inference:0.4.3-gpu-py38", + }, + "eu-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazona" + "ws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaw" + "s.com/autogluon-inference:0.4.3-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazona" + "ws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazona" + "ws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amaz" + "onaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazo" + "naws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1." 
+ "amazonaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.ama" + "zonaws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazona" + "ws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaw" + "s.com/autogluon-inference:0.4.3-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws" + ".com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amaz" + "onaws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amaz" + "onaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazo" + "naws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "sa-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amaz" + "onaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.ama" + "zonaws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "us-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.ama" + "zonaws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.am" + "azonaws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "us-east-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazo" + "naws.com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazon" + "aws.com/autogluon-inference:0.4.3-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws." + "com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws." + "com/autogluon-inference:0.4.3-gpu-py38", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws" + ".com/autogluon-inference:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws." 
+ "com/autogluon-inference:0.4.3-gpu-py38", + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south" + "-1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amaz" + "naws.com/autogluon-training:0.4.3-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-" + "1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amaz" + "onaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.ama" + "zonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.a" + "mazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2." + "amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2." + "amazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-sou" + "th-1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.ama" + "zonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-" + "1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast" + "-1.amazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast" + "-2.amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2." 
+ "amazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.a" + "mazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.am" + "azonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "eu-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.a" + "mazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.am" + "azonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.am" + "azonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazon" + "aws.com/autogluon-training:0.4.3-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amaz" + "onaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.ama" + "zonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amaz" + "onaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazo" + "naws.com/autogluon-training:0.4.3-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.am" + "azonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.am" + "azonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3." + "amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3." + "amazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-sout" + "h-1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1." + "amazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "sa-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1." + "amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1" + ".amazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "us-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1." + "amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1." 
+ "amazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "us-east-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2" + ".amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-" + "2.amazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west" + "-1.amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-wes" + "t-1.amazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west" + "-2.amazonaws.com/autogluon-training:0.4.3-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-w" + "est-2.amazonaws.com/autogluon-training:0.4.3-gpu-py38", + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + }, + }, + }, + "response-keys": { + "model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16", + "url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth", + "version": "1.0.0", + "min_sdk_version": "2.144.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "djl-deepspeed", + "framework_version": "0.21.0", + "py_version": "py38", + "huggingface_transformers_version": "4.17", + }, + "hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st" + "able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", + "hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/" + "infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.0", + "inference_vulnerable": False, + "inference_dependencies": [ + "accelerate==0.18.0", + "diffusers==0.14.0", + "fsspec==2023.4.0", + "huggingface-hub==0.14.1", + "transformers==4.26.1", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": 
[ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.g5.8xlarge", + "supported_inference_instance_types": [ + "ml.g5.8xlarge", + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.16xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.g4dn.2xlarge", + "ml.g4dn.4xlarge", + "ml.g4dn.8xlarge", + "ml.g4dn.16xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "sd-1-5-controlnet-1-1-fp16", + "default_payloads": { + "Dog": { + "content_type": "application/json", + "prompt_key": "hello.prompt", + "body": { + "hello": {"prompt": "a dog"}, + "seed": 43, + }, + } + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d" + "jl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4d": 
{"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + }, + }, + }, + "default_payloads": { + "model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16", + "url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth", + "version": "2.0.5", + "min_sdk_version": "2.189.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "djl-deepspeed", + "framework_version": "0.21.0", + "py_version": "py38", + "huggingface_transformers_version": "4.17", + }, + "hosting_artifact_key": "stabilityai-depth2img/model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "stabilityai-depth2img/model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [ + "accelerate==0.18.0", + "diffusers==0.14.0", + "fsspec==2023.4.0", + "huggingface-hub==0.14.1", + "transformers==4.26.1", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "default_payloads": { + "Dog": { + "content_type": "application/json", + "body": { + "prompt": "a dog", + "num_images_per_prompt": 2, + "num_inference_steps": 20, + "guidance_scale": 7.5, + "seed": 43, + "eta": 0.7, + "image": "$s3_b64", + }, + } + }, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.g5.8xlarge", + "supported_inference_instance_types": [ + "ml.g5.8xlarge", + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.16xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.g4dn.2xlarge", + "ml.g4dn.4xlarge", + 
"ml.g4dn.8xlarge", + "ml.g4dn.16xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "sd-1-5-controlnet-1-1-fp16", + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "ap-east-1": { + "alias_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "ap-northeast-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "ap-northeast-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "ap-northeast-3": { + "alias_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "ap-south-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "ap-southeast-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "ap-southeast-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "ap-southeast-3": { + "alias_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "ca-central-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "cn-north-1": { + "alias_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "cn-northwest-1": { + "alias_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "eu-central-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "eu-north-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "eu-south-1": { + "alias_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "eu-west-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "eu-west-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "eu-west-3": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "il-central-1": { + "alias_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "me-south-1": { + "alias_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "sa-east-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "us-east-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + 
"us-east-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "us-gov-east-1": { + "alias_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "us-gov-west-1": { + "alias_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "us-west-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + "us-west-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + }, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, + }, + "prompt-key": { + "model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16", + "url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth", + "version": "1.0.0", + "min_sdk_version": "2.144.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "djl-deepspeed", + "framework_version": "0.21.0", + "py_version": "py38", + "huggingface_transformers_version": "4.17", + }, + "hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st" + "able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", + "hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/" + "infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.0", + "inference_vulnerable": False, + "inference_dependencies": [ + "accelerate==0.18.0", + "diffusers==0.14.0", + "fsspec==2023.4.0", + "huggingface-hub==0.14.1", + "transformers==4.26.1", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": 
"ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.g5.8xlarge", + "supported_inference_instance_types": [ + "ml.g5.8xlarge", + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.16xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.g4dn.2xlarge", + "ml.g4dn.4xlarge", + "ml.g4dn.8xlarge", + "ml.g4dn.16xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "sd-1-5-controlnet-1-1-fp16", + "default_payloads": { + "Dog": { + "content_type": "application/json", + "prompt_key": "hello.prompt", + "body": { + "hello": {"prompt": "a dog"}, + "seed": 43, + }, + } + }, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d" + "jl-inference:0.21.0-deepspeed0.8.3-cu117" + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + }, + }, + }, + "predictor-specs-model": { + "model_id": "huggingface-text2text-flan-t5-xxl-fp16", + "url": "https://huggingface.co/google/flan-t5-xxl", + "version": 
"1.0.1", + "min_sdk_version": "2.130.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.12.0", + "py_version": "py38", + "huggingface_transformers_version": "4.17.0", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.3/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack-huggingface-" + "text2text-flan-t5-xxl-fp16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.1", + "inference_vulnerable": False, + "inference_dependencies": [ + "accelerate==0.16.0", + "bitsandbytes==0.37.0", + "filelock==3.9.0", + "huggingface_hub==0.12.0", + "regex==2022.7.9", + "tokenizers==0.13.2", + "transformers==4.26.0", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + }, + {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "text", + "default": "1", + "scope": "container", + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.g5.12xlarge", + "supported_inference_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.g4dn.12xlarge", + ], + "predictor_specs": { + "supported_content_types": ["application/x-text"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/x-text", + "default_accept_type": "application/json", + }, + }, + "model_data_s3_prefix_model": { + "model_id": "huggingface-text2text-flan-t5-xxl-fp16", + "url": "https://huggingface.co/google/flan-t5-xxl", + "version": "1.1.2", + "min_sdk_version": "2.144.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface", + "framework_version": "1.13.1", + "py_version": "py39", + "huggingface_transformers_version": "4.26.0", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.1.2/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.1.2/infer-prepack-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", + "hosting_prepacked_artifact_version": "1.1.2", + "inference_vulnerable": False, + "inference_dependencies": ["accelerate==0.19.0", "bitsandbytes==0.38.1", "peft==0.3.0"], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.19.0", + "datasets==2.12.0", + "deepspeed==0.9.2", + "peft==0.3.0", + "sagemaker_jumpstart_huggingface_script_utilities==1.0.2", + 
"sagemaker_jumpstart_script_utilities==1.1.4", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "epochs", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, + { + "name": "seed", + "type": "int", + "default": 42, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "batch_size", + "type": "int", + "default": 64, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": ["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.05, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "train_data_split_seed", "type": "int", "default": 0, "scope": "algorithm"}, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_eval_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_output_length", + "type": "int", + "default": 128, + "min": 0, + "scope": "algorithm", + }, + { + "name": "pad_to_max_length", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "gradient_accumulation_steps", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + "type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_strategy", + "type": "text", + "default": "steps", + "options": ["no", "steps", "epoch"], + "scope": "algorithm", + }, + { + "name": "logging_first_step", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 500, + "min": 1, + "scope": "algorithm", + }, + { + "name": "logging_nan_inf_filter", + 
"type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "epoch", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "text", "default": "2", "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "evalaution_strategy", + "type": "text", + "default": "epoch", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "eval_steps", "type": "text", "default": "500", "scope": "algorithm"}, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "gradient_checkpointing", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "peft_type", + "type": "text", + "default": "lora", + "options": ["lora"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/v1.2.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/huggingface/transfer_learning/text2text/prepack/v1.1.2/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.2", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "1.13.1", + "py_version": "py39", + "huggingface_transformers_version": "4.26.0", + }, + "training_artifact_key": "huggingface-training/train-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": 
"TS_DEFAULT_WORKERS_PER_MODEL", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + {"Name": "huggingface-text2text:eval-loss", "Regex": "'eval_loss': ([0-9\\.]+)"} + ], + "default_inference_instance_type": "ml.g5.12xlarge", + "supported_inference_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.g4dn.12xlarge", + ], + "default_training_instance_type": "ml.g5.24xlarge", + "supported_training_instance_types": ["ml.g5.24xlarge", "ml.g5.48xlarge"], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, + }, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/x-text", "application/json"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/x-text", + "default_accept_type": "application/json", + }, + "inference_volume_size": 256, + "training_volume_size": 256, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/genuq/dev/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-text2text-flan-t5-xxl-fp16", + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": 
"763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "eu-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "sa-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": 
"763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "us-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "us-east-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-cpu-py39-ubuntu20.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": 
"763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": 
"$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "dynamic_container_deployment_supported": False, + }, + "no-supported-instance-types-model": { + "model_id": "pytorch-ic-mobilenet-v2", + "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", + "version": "1.0.0", + "min_sdk_version": "2.49.0", + "training_supported": True, + "incremental_training_supported": True, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "hyperparameters": [ + { + "name": "epochs", + "type": "int", + "default": 3, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "adam-learning-rate", + "type": "float", + "default": 0.05, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "batch-size", + "type": "int", + "default": 4, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + }, + {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "text", + "default": "1", + "scope": "container", + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + }, + ], + "default_inference_instance_type": "", + "supported_inference_instance_types": None, + "default_training_instance_type": None, + "supported_training_instance_types": [], + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "metrics": [], + }, + "huggingface-text2text-flan-t5-xxl-fp16": { + "model_id": "huggingface-text2text-flan-t5-xxl-fp16", + "url": "https://huggingface.co/google/flan-t5-xxl", + "version": "1.0.0", + "min_sdk_version": "2.130.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.12.0", + "py_version": "py38", + "huggingface_transformers_version": "4.17.0", + }, + "hosting_artifact_key": 
"huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.2/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-" + "text2text-flan-t5-xxl-fp16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.0", + "inference_vulnerable": False, + "inference_dependencies": [ + "accelerate==0.16.0", + "bitsandbytes==0.37.0", + "filelock==3.9.0", + "huggingface-hub==0.12.0", + "regex==2022.7.9", + "tokenizers==0.13.2", + "transformers==4.26.0", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + }, + {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "text", + "default": "1", + "scope": "container", + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + }, + ], + "inference_vulnerable": False, + "training_vulnerable": False, + "deprecated": False, + "default_training_instance_type": None, + "supported_training_instance_types": [], + "metrics": [], + "default_inference_instance_type": "ml.g5.12xlarge", + "supported_inference_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.g4dn.12xlarge", + ], + }, + "mock-model-training-prepacked-script-key": { + "model_id": "sklearn-classification-linear", + "url": "https://scikit-learn.org/stable/", + "version": "1.0.0", + "min_sdk_version": "2.68.1", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "sklearn", + "framework_version": "0.23-1", + "py_version": "py3", + }, + "hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz", + "hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz", + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "tol", + "type": "float", + "default": 0.0001, + "min": 1e-20, + "max": 50, + "scope": "algorithm", + }, + { + "name": "penalty", + "type": "text", + "default": "l2", + "options": ["l1", "l2", "elasticnet", "none"], + "scope": "algorithm", + }, + { + "name": "alpha", + "type": "float", + "default": 0.0001, + "min": 1e-20, + "max": 999, + "scope": "algorithm", + }, + { + "name": "l1_ratio", + "type": "float", + "default": 0.15, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", 
+ }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/sklearn/transfer_learning/classification/" + "v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": "some/key/to/training_prepacked_script_key.tar.gz", + "training_ecr_specs": { + "framework_version": "0.23-1", + "framework": "sklearn", + "py_version": "py3", + }, + "training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + }, + {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "text", + "default": "1", + "scope": "container", + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + }, + ], + }, +} + + +PROTOTYPICAL_MODEL_SPECS_DICT = { + "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1": { + "model_id": "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", + "url": "https://tfhub.dev/google/bit/m-r101x1/ilsvrc2012_classification/1", + "version": "4.0.6", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": True, + "hosting_ecr_specs": { + "framework": "tensorflow", + "framework_version": "2.8", + "py_version": "py39", + }, + "hosting_artifact_key": "tensorflow-ic/tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1/artifacts/inference/v3.0.0/", + "hosting_script_key": "source-directory-tarballs/tensorflow/inference/ic/v2.0.3/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "tensorflow-ic/tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "sagemaker_jumpstart_prepack_script_utilities==1.0.0", + "sagemaker_jumpstart_script_utilities==1.1.1", + "sagemaker_jumpstart_tensorflow_script_utilities==1.0.1", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "train_only_top_layer", + "type": "text", + "default": "True", + "options": ["False", "True"], + "scope": "algorithm", + }, + { + "name": "epochs", + "type": "int", + "default": 5, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "batch_size", + "type": "int", + "default": 32, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "optimizer", + "type": "text", + "default": "adam", + "options": ["adam", "sgd", "nesterov", "rmsprop", "adagrad", "adadelta"], + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.001, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "beta_1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "beta_2", + "type": "float", + "default": 
0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "momentum", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "epsilon", + "type": "float", + "default": 1e-07, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "rho", + "type": "float", + "default": 0.95, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "initial_accumulator_value", + "type": "float", + "default": 0.1, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "reinitialize_top_layer", + "type": "text", + "default": "Auto", + "options": ["Auto", "True", "False"], + "scope": "algorithm", + }, + { + "name": "early_stopping", + "type": "text", + "default": "False", + "options": ["False", "True"], + "scope": "algorithm", + }, + { + "name": "early_stopping_patience", + "type": "int", + "default": 5, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "early_stopping_min_delta", + "type": "float", + "default": 0.0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "dropout_rate", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "regularizers_l2", + "type": "float", + "default": 0.0001, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "label_smoothing", + "type": "float", + "default": 0.1, + "min": 0, + "max": 1, + "scope": "algorithm", }, - }, - }, - "response-keys": { - "model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16", - "url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth", - "version": "1.0.0", - "min_sdk_version": "2.144.0", - "training_supported": False, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "djl-deepspeed", - "framework_version": "0.21.0", - "py_version": "py38", - "huggingface_transformers_version": "4.17", - }, - "hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st" - "able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", - "hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/" - "infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.0", - "inference_vulnerable": False, - "inference_dependencies": [ - "accelerate==0.18.0", - "diffusers==0.14.0", - "fsspec==2023.4.0", - "huggingface-hub==0.14.1", - "transformers==4.26.1", - ], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "inference_environment_variables": [ { - "name": "SAGEMAKER_PROGRAM", + "name": "image_resize_interpolation", "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, + "default": "bilinear", + "options": [ + "bilinear", + "nearest", + "bicubic", + "area", + "lanczos3", + "lanczos5", + "gaussian", + "mitchellcubic", + ], + "scope": "algorithm", }, { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "name": "augmentation", "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + "default": "False", + "options": ["False", "True"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "augmentation_random_flip", "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + 
"default": "horizontal_and_vertical", + "options": ["horizontal_and_vertical", "horizontal", "vertical", "None"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "augmentation_random_rotation", + "type": "float", + "default": 0.2, + "min": -1, + "max": 1, + "scope": "algorithm", + }, + { + "name": "augmentation_random_zoom", + "type": "float", + "default": 0.1, + "min": -1, + "max": 1, + "scope": "algorithm", + }, + { + "name": "binary_mode", "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, + "default": "False", + "options": ["False", "True"], + "scope": "algorithm", }, { - "name": "ENDPOINT_SERVER_TIMEOUT", - "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, + "name": "eval_metric", + "type": "text", + "default": "accuracy", + "options": ["accuracy", "precision", "recall", "auc", "prc"], + "scope": "algorithm", }, { - "name": "MODEL_CACHE_ROOT", + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + {"name": "random_seed", "type": "int", "default": 123, "min": 0, "scope": "algorithm"}, + { + "name": "sagemaker_submit_directory", "type": "text", - "default": "/opt/ml/model", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", - "required_for_model_class": True, }, { - "name": "SAGEMAKER_ENV", + "name": "sagemaker_program", "type": "text", - "default": "1", + "default": "transfer_learning.py", "scope": "container", - "required_for_model_class": True, }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", "scope": "container", - "required_for_model_class": True, }, ], - "metrics": [], - "default_inference_instance_type": "ml.g5.8xlarge", - "supported_inference_instance_types": [ - "ml.g5.8xlarge", - "ml.g5.xlarge", - "ml.g5.2xlarge", - "ml.g5.4xlarge", - "ml.g5.16xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", - "ml.g4dn.2xlarge", - "ml.g4dn.4xlarge", - "ml.g4dn.8xlarge", - "ml.g4dn.16xlarge", - ], - "model_kwargs": {}, - "deploy_kwargs": {}, - "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", - "default_accept_type": "application/json", - }, - "inference_enable_network_isolation": True, - "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "sd-1-5-controlnet-1-1-fp16", - "default_payloads": { - "Dog": { - "content_type": "application/json", - "prompt_key": "hello.prompt", - "body": { - "hello": {"prompt": "a dog"}, - "seed": 43, - }, - } - }, - "hosting_instance_type_variants": { - "regional_aliases": { - "af-south-1": { - "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d" - "jl-inference:0.21.0-deepspeed0.8.3-cu117" - }, - }, - "variants": { - "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "inf2": 
{"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - }, - }, - }, - "default_payloads": { - "model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16", - "url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth", - "version": "1.0.0", - "min_sdk_version": "2.144.0", - "training_supported": False, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "djl-deepspeed", - "framework_version": "0.21.0", - "py_version": "py38", - "huggingface_transformers_version": "4.17", + "training_script_key": "source-directory-tarballs/tensorflow/transfer_learning/ic/v2.1.2/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/tensorflow/transfer_learning/ic/prepack/v1.1.2/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.2", + "training_ecr_specs": { + "framework": "tensorflow", + "framework_version": "2.9", + "py_version": "py39", }, - "hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st" - "able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", - "hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/" - "infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.0", - "inference_vulnerable": False, - "inference_dependencies": [ - "accelerate==0.18.0", - "diffusers==0.14.0", - "fsspec==2023.4.0", - "huggingface-hub==0.14.1", - "transformers==4.26.1", - ], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, + "training_artifact_key": "tensorflow-training/v3.0.0/train-tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", @@ -5516,114 +11176,473 @@ "required_for_model_class": True, }, ], - "metrics": [], - "default_inference_instance_type": "ml.g5.8xlarge", - "supported_inference_instance_types": [ - "ml.g5.8xlarge", - "ml.g5.xlarge", - "ml.g5.2xlarge", - "ml.g5.4xlarge", - "ml.g5.16xlarge", + "metrics": [{"Name": "tflow-ic:val-accuracy", "Regex": "val_accuracy: ([0-9\\.]+)"}], + "default_inference_instance_type": "ml.p3.2xlarge", + "supported_inference_instance_types": [ + "ml.p3.2xlarge", + "ml.p2.xlarge", + "ml.g4dn.xlarge", + "ml.m5.xlarge", + "ml.m4.xlarge", + "ml.m5.large", + "ml.c5.2xlarge", + "ml.c5.xlarge", 
+ "ml.r5.xlarge", + "ml.r5.large", + "ml.c6i.xlarge", + "ml.c6i.large", + ], + "default_training_instance_type": "ml.p3.2xlarge", + "supported_training_instance_types": [ "ml.p3.2xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.p2.xlarge", + "ml.p2.8xlarge", + "ml.p2.16xlarge", + "ml.g5.xlarge", "ml.g4dn.xlarge", "ml.g4dn.2xlarge", "ml.g4dn.4xlarge", "ml.g4dn.8xlarge", "ml.g4dn.16xlarge", + "ml.m5.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", ], "model_kwargs": {}, "deploy_kwargs": {}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", + "supported_content_types": ["application/x-image"], + "supported_accept_types": ["application/json", "application/json;verbose"], + "default_content_type": "application/x-image", "default_accept_type": "application/json", }, "inference_enable_network_isolation": True, - "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "sd-1-5-controlnet-1-1-fp16", - "default_payloads": { - "Dog": { - "content_type": "application/json", - "body": { - "prompt": "a dog", - "num_images_per_prompt": 2, - "num_inference_steps": 20, - "guidance_scale": 7.5, - "seed": 43, - "eta": 0.7, - "image": "$s3_b64", + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/tf_flowers/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "bit-m-r101x1-ilsvrc2012-classification", + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/tensorflow-inference:2.8-gpu", }, - } + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": 
"763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/tensorflow-inference:2.8-gpu", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/tensorflow-inference:2.8-gpu", + }, + "eu-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "sa-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": 
"763104351884.dkr.ecr.sa-east-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "us-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "us-east-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:2.8-cpu", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:2.8-gpu", + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": 
"$cpu_ecr_uri_1"}}, + }, }, - "hosting_instance_type_variants": { + "training_instance_type_variants": { "regional_aliases": { "af-south-1": { - "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d" - "jl-inference:0.21.0-deepspeed0.8.3-cu117" + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/tensorflow-training:2.9-gpu-py39", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/tensorflow-training:2.9-gpu-py39", + }, + "eu-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "eu-central-2": { + "cpu_ecr_uri_1": 
"380420809688.dkr.ecr.eu-central-2.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "sa-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "us-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "us-east-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/tensorflow-training:2.9-gpu-py39", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.9-cpu-py39", + "gpu_ecr_uri_2": 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:2.9-gpu-py39", }, }, "variants": { - "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": 
"$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, }, }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, }, - "prompt-key": { - "model_id": "model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16", - "url": "https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth", - "version": "1.0.0", - "min_sdk_version": "2.144.0", - "training_supported": False, - "incremental_training_supported": False, + "pytorch-ic-mobilenet-v2": { + "model_id": "pytorch-ic-mobilenet-v2", + "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", + "version": "3.0.6", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": True, "hosting_ecr_specs": { - "framework": "djl-deepspeed", - "framework_version": "0.21.0", + "framework": "pytorch", + "framework_version": "1.10.0", "py_version": "py38", - "huggingface_transformers_version": "4.17", }, - "hosting_artifact_key": "stabilityai-infer/infer-model-depth2img-st" - "able-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", - "hosting_script_key": "source-directory-tarballs/stabilityai/inference/depth2img/v1.0.0/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "stabilityai-infer/prepack/v1.0.0/" - "infer-prepack-model-depth2img-stable-diffusion-v1-5-controlnet-v1-1-fp16.tar.gz", + "hosting_artifact_key": "pytorch-ic/pytorch-ic-mobilenet-v2/artifacts/inference/v2.0.0/", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v2.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "pytorch-ic/pytorch-ic-mobilenet-v2/artifacts/inference-prepack/v1.0.0/", "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, "inference_vulnerable": False, - "inference_dependencies": [ - "accelerate==0.18.0", - "diffusers==0.14.0", - "fsspec==2023.4.0", - "huggingface-hub==0.14.1", - "transformers==4.26.1", - ], + "inference_dependencies": [], "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], + "training_dependencies": ["sagemaker_jumpstart_prepack_script_utilities==1.0.0"], "training_vulnerabilities": [], "deprecated": False, + "hyperparameters": [ + { + "name": "train_only_top_layer", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "epochs", + "type": "int", + "default": 5, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "batch_size", + "type": "int", + "default": 4, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "reinitialize_top_layer", + "type": "text", + "default": "Auto", + "options": ["Auto", "True", "False"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, 
+ ], + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v2.3.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/prepack/v1.1.0/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.0", + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.10.0", + "py_version": "py38", + }, + "training_artifact_key": "pytorch-training/v2.0.0/train-pytorch-ic-mobilenet-v2.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", @@ -5682,279 +11701,411 @@ "required_for_model_class": True, }, ], - "metrics": [], - "default_inference_instance_type": "ml.g5.8xlarge", + "metrics": [{"Name": "pytorch-ic:val-accuracy", "Regex": "val_accuracy: ([0-9\\.]+)"}], + "default_inference_instance_type": "ml.m5.large", "supported_inference_instance_types": [ - "ml.g5.8xlarge", - "ml.g5.xlarge", - "ml.g5.2xlarge", - "ml.g5.4xlarge", - "ml.g5.16xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", - "ml.g4dn.2xlarge", - "ml.g4dn.4xlarge", - "ml.g4dn.8xlarge", - "ml.g4dn.16xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + "ml.m4.large", + "ml.m4.xlarge", ], + "default_training_instance_type": "ml.m5.xlarge", + "supported_training_instance_types": ["ml.m5.xlarge", "ml.c5.2xlarge", "ml.m4.xlarge"], "model_kwargs": {}, "deploy_kwargs": {}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", + "supported_content_types": ["application/x-image"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/x-image", "default_accept_type": "application/json", }, "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/tf_flowers/", "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "sd-1-5-controlnet-1-1-fp16", - "default_payloads": { - "Dog": { - "content_type": "application/json", - "prompt_key": "hello.prompt", - "body": { - "hello": {"prompt": "a dog"}, - "seed": 43, - }, - } - }, + "fine_tuning_supported": True, + "resource_name_base": "pt-ic-mobilenet-v2", "hosting_instance_type_variants": { "regional_aliases": { "af-south-1": { - "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/d" - "jl-inference:0.21.0-deepspeed0.8.3-cu117" + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-northeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + 
"gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-south-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:1.10.0-gpu-py38", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-central-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": 
"692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-west-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "sa-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-east-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-west-2": { + "alias_ecr_uri_3": 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", }, }, "variants": { - "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, - }, - }, - }, - "predictor-specs-model": { - "model_id": "huggingface-text2text-flan-t5-xxl-fp16", - "url": "https://huggingface.co/google/flan-t5-xxl", - "version": "1.0.1", - "min_sdk_version": "2.130.0", - "training_supported": False, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.12.0", - "py_version": "py38", - "huggingface_transformers_version": "4.17.0", - }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.3/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack-huggingface-" - "text2text-flan-t5-xxl-fp16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.1", - "inference_vulnerable": False, - "inference_dependencies": [ - "accelerate==0.16.0", - "bitsandbytes==0.37.0", - "filelock==3.9.0", - "huggingface_hub==0.12.0", - "regex==2022.7.9", - "tokenizers==0.13.2", - "transformers==4.26.0", - ], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - 
"type": "text", - "default": "20", - "scope": "container", - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "text", - "default": "1", - "scope": "container", + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-northeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + 
"ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-south-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-training:1.10.0-gpu-py38", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.10.0-gpu-py38", + }, + "eu-central-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": 
"763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-west-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "sa-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-east-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + 
"us-west-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, }, - ], - "metrics": [], - "default_inference_instance_type": "ml.g5.12xlarge", - "supported_inference_instance_types": [ - "ml.g5.12xlarge", - "ml.g5.24xlarge", - "ml.p3.8xlarge", - "ml.p3.16xlarge", - "ml.g4dn.12xlarge", - ], - "predictor_specs": { - "supported_content_types": ["application/x-text"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-text", - "default_accept_type": "application/json", }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, }, - "model_data_s3_prefix_model": { - "model_id": "huggingface-text2text-flan-t5-xxl-fp16", - "url": "https://huggingface.co/google/flan-t5-xxl", - "version": 
"1.0.1", - "min_sdk_version": "2.130.0", - "training_supported": False, - "incremental_training_supported": False, + "mxnet-semseg-fcn-resnet50-ade": { + "model_id": "mxnet-semseg-fcn-resnet50-ade", + "url": "https://cv.gluon.ai/model_zoo/segmentation.html", + "version": "2.0.3", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": True, "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.12.0", + "framework": "mxnet", + "framework_version": "1.9.0", "py_version": "py38", - "huggingface_transformers_version": "4.17.0", }, - "hosting_artifact_key": "huggingface-infer/", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.3/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/", - "hosting_prepacked_artifact_version": "1.0.1", + "hosting_artifact_key": "mxnet-semseg/mxnet-semseg-fcn-resnet50-ade/artifacts/inference/v1.1.0/", + "hosting_script_key": "source-directory-tarballs/mxnet/inference/semseg/v1.2.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "mxnet-semseg/mxnet-semseg-fcn-resnet50-ade/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, "inference_vulnerable": False, - "inference_dependencies": [ - "accelerate==0.16.0", - "bitsandbytes==0.37.0", - "filelock==3.9.0", - "huggingface_hub==0.12.0", - "regex==2022.7.9", - "tokenizers==0.13.2", - "transformers==4.26.0", - ], + "inference_dependencies": [], "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], + "training_dependencies": [ + "numpy==1.23.1", + "opencv_python==4.7.0.68", + "sagemaker_jumpstart_prepack_script_utilities==1.0.0", + ], "training_vulnerabilities": [], "deprecated": False, - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "text", - "default": "1", - "scope": "container", - }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - }, - ], - "metrics": [], - "default_inference_instance_type": "ml.g5.12xlarge", - "supported_inference_instance_types": [ - "ml.g5.12xlarge", - "ml.g5.24xlarge", - "ml.p3.8xlarge", - "ml.p3.16xlarge", - "ml.g4dn.12xlarge", - ], - "predictor_specs": { - "supported_content_types": ["application/x-text"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-text", - "default_accept_type": "application/json", - }, - }, - "no-supported-instance-types-model": { - "model_id": "pytorch-ic-mobilenet-v2", - "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", - "version": "1.0.0", - "min_sdk_version": "2.49.0", - "training_supported": True, - "incremental_training_supported": True, - "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", - }, - "training_ecr_specs": { - "framework": 
"pytorch", - "framework_version": "1.5.0", - "py_version": "py3", - }, - "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", - "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", - "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", "hyperparameters": [ { "name": "epochs", "type": "int", - "default": 3, + "default": 5, "min": 1, "max": 1000, "scope": "algorithm", @@ -5962,7 +12113,7 @@ { "name": "adam-learning-rate", "type": "float", - "default": 0.05, + "default": 0.001, "min": 1e-08, "max": 1, "scope": "algorithm", @@ -5970,220 +12121,517 @@ { "name": "batch-size", "type": "int", - "default": 4, + "default": 2, "min": 1, "max": 1024, - "scope": "algorithm", - }, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", - }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", - }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", - }, - ], - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "train-only-top-layer", "type": "text", - "default": "20", - "scope": "container", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "MODEL_CACHE_ROOT", + "name": "sagemaker_submit_directory", "type": "text", - "default": "/opt/ml/model", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "name": "sagemaker_program", "type": "text", - "default": "1", + "default": "transfer_learning.py", "scope": "container", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "sagemaker_container_log_level", "type": "text", - "default": "3600", + "default": "20", "scope": "container", }, ], - "default_inference_instance_type": "", - "supported_inference_instance_types": None, - "default_training_instance_type": None, - "supported_training_instance_types": [], - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "metrics": [], - }, - "huggingface-text2text-flan-t5-xxl-fp16": { - "model_id": "huggingface-text2text-flan-t5-xxl-fp16", - "url": "https://huggingface.co/google/flan-t5-xxl", - "version": "1.0.0", - "min_sdk_version": "2.130.0", - "training_supported": False, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.12.0", + "training_script_key": "source-directory-tarballs/mxnet/transfer_learning/semseg/v1.5.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/mxnet/transfer_learning/semseg/prepack/v1.1.0/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.0", + "training_ecr_specs": { + "framework": "mxnet", + "framework_version": 
"1.9.0", "py_version": "py38", - "huggingface_transformers_version": "4.17.0", }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.2/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-" - "text2text-flan-t5-xxl-fp16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.0", - "inference_vulnerable": False, - "inference_dependencies": [ - "accelerate==0.16.0", - "bitsandbytes==0.37.0", - "filelock==3.9.0", - "huggingface-hub==0.12.0", - "regex==2022.7.9", - "tokenizers==0.13.2", - "transformers==4.26.0", - ], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, + "training_artifact_key": "mxnet-training/train-mxnet-semseg-fcn-resnet50-ade.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", "type": "text", "default": "inference.py", "scope": "container", + "required_for_model_class": True, }, { "name": "SAGEMAKER_SUBMIT_DIRECTORY", "type": "text", "default": "/opt/ml/model/code", "scope": "container", + "required_for_model_class": False, }, { "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", "type": "text", "default": "20", "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, }, { "name": "MODEL_CACHE_ROOT", "type": "text", "default": "/opt/ml/model", "scope": "container", + "required_for_model_class": True, }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [{"Name": "mxnet-semseg:val-loss", "Regex": "validation loss=([0-9\\.]+)"}], + "default_inference_instance_type": "ml.p3.2xlarge", + "supported_inference_instance_types": [ + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + ], + "default_training_instance_type": "ml.p3.2xlarge", + "supported_training_instance_types": [ + "ml.p3.2xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.g4dn.8xlarge", + "ml.g4dn.16xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/x-image"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/x-image", + "default_accept_type": "application/json", + }, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/PennFudanPed_SemSeg/", + "validation_supported": False, + "fine_tuning_supported": True, + "resource_name_base": "mx-semseg-fcn-resnet50-ade", + 
"hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/mxnet-inference:1.9.0-gpu-py38", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/mxnet-inference:1.9.0-gpu-py38", + }, + "eu-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": 
"692866216735.dkr.ecr.eu-south-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "sa-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "us-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "us-east-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.9.0-gpu-py38", + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + 
"g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + 
}, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/mxnet-training:1.9.0-gpu-py38", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/mxnet-training:1.9.0-gpu-py38", + }, + "eu-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "sa-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "us-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + 
}, + "us-east-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:1.9.0-gpu-py38", + }, }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, }, - ], - "inference_vulnerable": False, - "training_vulnerable": False, - "deprecated": False, - "default_training_instance_type": None, - 
"supported_training_instance_types": [], - "metrics": [], - "default_inference_instance_type": "ml.g5.12xlarge", - "supported_inference_instance_types": [ - "ml.g5.12xlarge", - "ml.g5.24xlarge", - "ml.p3.8xlarge", - "ml.p3.16xlarge", - "ml.g4dn.12xlarge", - ], + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, }, - "mock-model-training-prepacked-script-key": { - "model_id": "sklearn-classification-linear", - "url": "https://scikit-learn.org/stable/", - "version": "1.0.0", - "min_sdk_version": "2.68.1", + "huggingface-spc-bert-base-cased": { + "model_id": "huggingface-spc-bert-base-cased", + "url": "https://huggingface.co/bert-base-cased", + "version": "2.0.3", + "min_sdk_version": "2.189.0", "training_supported": True, - "incremental_training_supported": False, + "incremental_training_supported": True, "hosting_ecr_specs": { - "framework": "sklearn", - "framework_version": "0.23-1", - "py_version": "py3", + "framework": "huggingface", + "framework_version": "1.7.1", + "py_version": "py36", + "huggingface_transformers_version": "4.6.1", }, - "hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz", - "hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz", + "hosting_artifact_key": "huggingface-spc/huggingface-spc-bert-base-cased/artifacts/inference/v1.2.0/", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/spc/v1.1.3/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-spc/huggingface-spc-bert-base-cased/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, "inference_vulnerable": False, "inference_dependencies": [], "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], + "training_dependencies": ["sagemaker_jumpstart_prepack_script_utilities==1.0.0"], "training_vulnerabilities": [], "deprecated": False, "hyperparameters": [ { - "name": "tol", + "name": "epochs", + "type": "int", + "default": 3, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "adam-learning-rate", "type": "float", - "default": 0.0001, - "min": 1e-20, - "max": 50, + "default": 2e-05, + "min": 1e-08, + "max": 1, "scope": "algorithm", }, { - "name": "penalty", - "type": "text", - "default": "l2", - "options": ["l1", "l2", "elasticnet", "none"], + "name": "batch-size", + "type": "int", + "default": 8, + "min": 1, + "max": 1024, "scope": "algorithm", }, { - "name": "alpha", - "type": "float", - "default": 0.0001, - "min": 1e-20, - "max": 999, + "name": "reinitialize-top-layer", + "type": "text", + "default": "Auto", + "options": ["Auto", "True", "False"], "scope": "algorithm", }, { - "name": "l1_ratio", - "type": "float", - "default": 0.15, - "min": 0, - "max": 1, + "name": "train-only-top-layer", + "type": "text", + "default": "False", + "options": ["True", "False"], "scope": "algorithm", }, { @@ -6205,245 +12653,542 @@ "scope": "container", }, ], - "training_script_key": "source-directory-tarballs/sklearn/transfer_learning/classification/" - "v1.0.0/sourcedir.tar.gz", - "training_prepacked_script_key": "some/key/to/training_prepacked_script_key.tar.gz", + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/spc/v1.3.0/sourcedir.tar.gz", + "training_prepacked_script_key": 
"source-directory-tarballs/huggingface/transfer_learning/spc/prepack/v1.1.0/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.0", "training_ecr_specs": { - "framework_version": "0.23-1", - "framework": "sklearn", - "py_version": "py3", + "framework": "huggingface", + "framework_version": "1.6.0", + "py_version": "py36", + "huggingface_transformers_version": "4.4.2", }, - "training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz", + "training_artifact_key": "huggingface-training/train-huggingface-spc-bert-base-cased.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", "type": "text", "default": "inference.py", "scope": "container", + "required_for_model_class": True, }, { "name": "SAGEMAKER_SUBMIT_DIRECTORY", "type": "text", "default": "/opt/ml/model/code", "scope": "container", + "required_for_model_class": False, }, { "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", "type": "text", "default": "20", "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, }, { "name": "MODEL_CACHE_ROOT", "type": "text", "default": "/opt/ml/model", "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + {"Name": "hugginface-spc:eval-accuracy", "Regex": "'eval_accuracy': ([0-9\\.]+)"} + ], + "default_inference_instance_type": "ml.p3.2xlarge", + "supported_inference_instance_types": [ + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + ], + "default_training_instance_type": "ml.p3.2xlarge", + "supported_training_instance_types": [ + "ml.p3.2xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.p2.xlarge", + "ml.p2.8xlarge", + "ml.p2.16xlarge", + "ml.g4dn.xlarge", + "ml.g4dn.2xlarge", + "ml.g4dn.4xlarge", + "ml.g4dn.8xlarge", + "ml.g4dn.16xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/list-text"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/list-text", + "default_accept_type": "application/json", + }, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/QNLI/", + "validation_supported": False, + "fine_tuning_supported": True, + "resource_name_base": "hf-spc-bert-base-cased", + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "ap-east-1": { + "cpu_ecr_uri_1": 
"871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "eu-central-1": { + 
"cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "sa-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "us-east-1": { + "cpu_ecr_uri_1": 
"763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "us-east-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-cpu-py36-ubuntu18.04", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.7.1-transformers4.6.1-gpu-py36-cu110-ubuntu18.04", + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": 
"$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "text", - "default": "1", - "scope": "container", + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "ap-northeast-3": { + "gpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "ap-southeast-3": { + "gpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "cn-northwest-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "eu-central-2": { + "gpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + 
"eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "il-central-1": { + "gpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "me-central-1": { + "gpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "us-gov-east-1": { + "gpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "us-gov-west-1": { + "gpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:1.6.0-transformers4.4.2-gpu-py36-cu110-ubuntu18.04" + }, }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, - ], + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, }, -} - - -PROTOTYPICAL_MODEL_SPECS_DICT = { - "pytorch-eqa-bert-base-cased": { - "model_id": "pytorch-eqa-bert-base-cased", - "url": "https://pytorch.org/hub/huggingface_pytorch-transformers/", - "version": "1.0.0", - "min_sdk_version": "2.68.1", + "lightgbm-classification-model": { + "model_id": 
"lightgbm-classification-model", + "url": "https://lightgbm.readthedocs.io/en/latest/", + "version": "2.1.6", + "min_sdk_version": "2.189.0", "training_supported": True, - "incremental_training_supported": False, + "incremental_training_supported": True, "hosting_ecr_specs": { "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "framework_version": "2.0.1", + "py_version": "py310", }, - "default_inference_instance_type": "ml.p2.xlarge", - "supported_inference_instance_type": ["ml.p2.xlarge", "ml.p3.xlarge"], - "default_training_instance_type": "ml.p2.xlarge", - "supported_training_instance_type": ["ml.p2.xlarge", "ml.p3.xlarge"], - "hosting_artifact_key": "pytorch-infer/infer-pytorch-eqa-bert-base-cased.tar.gz", - "hosting_script_key": "source-directory-tarballs/pytorch/inference/eqa/v1.0.0/sourcedir.tar.gz", + "hosting_artifact_key": "lightgbm-classification/lightgbm-classification-model/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/lightgbm/inference/classification/v1.2.2/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "lightgbm-classification/lightgbm-classification-model/artifacts/inference-prepack/v1.0.1/", + "hosting_prepacked_artifact_version": "1.0.1", + "hosting_use_script_uri": False, "inference_vulnerable": False, - "inference_dependencies": [ - "transformers==3.5.1", - "dataclasses==0.8", - "filelock==3.0.12", - "packaging==20.8", - "pyparsing==2.4.7", - "regex==2020.11.13", - "sacremoses==0.0.43", - "sentencepiece==0.1.91", - "tokenizers==0.9.3", - ], + "inference_dependencies": ["lightgbm==4.1.0"], "inference_vulnerabilities": [], "training_vulnerable": False, "training_dependencies": [ - "transformers==3.5.1", - "dataclasses==0.8", - "filelock==3.0.12", - "packaging==20.8", - "pyparsing==2.4.7", - "regex==2020.11.13", - "sacremoses==0.0.43", - "sentencepiece==0.1.91", - "tokenizers==0.9.3", + "HeapDict==1.0.1", + "dask==2022.12.1", + "distributed==2022.12.1", + "graphviz==0.17", + "lightgbm==3.3.3", + "locket==1.0.0", + "msgpack==1.0.4", + "partd==1.3.0", + "sagemaker_jumpstart_prepack_script_utilities==1.0.0", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "sortedcontainers==2.4.0", + "tblib==1.7.0", + "toolz==0.12.0", + "zict==2.2.0", ], "training_vulnerabilities": [], "deprecated": False, "hyperparameters": [ { - "name": "epochs", + "name": "num_boost_round", "type": "int", - "default": 3, + "default": 5000, "min": 1, - "max": 1000, + "max": 100000, "scope": "algorithm", }, + {"name": "early_stopping_rounds", "type": "int", "default": 30, "scope": "algorithm"}, + {"name": "metric", "type": "text", "default": "auto", "scope": "algorithm"}, { - "name": "adam-learning-rate", + "name": "learning_rate", "type": "float", - "default": 2e-05, - "min": 1e-08, - "max": 1, + "default": 0.009, + "min": 1e-20, "scope": "algorithm", }, { - "name": "batch-size", + "name": "num_leaves", "type": "int", - "default": 4, - "min": 1, - "max": 1024, + "default": 67, + "min": 2, + "max": 131072, "scope": "algorithm", }, { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", + "name": "feature_fraction", + "type": "float", + "default": 0.74, + "min": 1e-20, + "max": 1, + "scope": "algorithm", }, { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", + "name": "bagging_fraction", + "type": "float", + "default": 0.53, + "min": 1e-20, + "max": 1, + "scope": "algorithm", }, { - 
"name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", + "name": "bagging_freq", + "type": "int", + "default": 5, + "min": 0, + "max": 100000, + "scope": "algorithm", }, - ], - "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/eqa/v1.0.0/sourcedir.tar.gz", - "training_ecr_specs": { - "framework_version": "1.5.0", - "framework": "pytorch", - "py_version": "py3", - }, - "training_artifact_key": "pytorch-training/train-pytorch-eqa-bert-base-cased.tar.gz", - "predictor_specs": { - "supported_content_types": ["application/x-image"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-image", - "default_accept_type": "application/json", - }, - "inference_environment_variables": [ + {"name": "max_depth", "type": "int", "default": 11, "scope": "algorithm"}, { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", + "name": "min_data_in_leaf", + "type": "int", + "default": 26, + "min": 0, + "scope": "algorithm", }, + {"name": "max_delta_step", "type": "float", "default": 0.0, "scope": "algorithm"}, { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", + "name": "lambda_l1", + "type": "float", + "default": 0.0, + "min": 0.0, + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", + "name": "lambda_l2", + "type": "float", + "default": 0.0, + "min": 0.0, + "scope": "algorithm", }, { - "name": "MODEL_CACHE_ROOT", + "name": "boosting", "type": "text", - "default": "/opt/ml/model", - "scope": "container", + "default": "gbdt", + "options": ["gbdt", "rf", "dart", "goss"], + "scope": "algorithm", }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "text", - "default": "1", - "scope": "container", + "name": "min_gain_to_split", + "type": "float", + "default": 0.0, + "min": 0.0, + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", + "name": "scale_pos_weight", + "type": "float", + "default": 1.0, + "min": 1e-20, + "scope": "algorithm", }, - ], - }, - "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1": { - "model_id": "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", - "url": "https://tfhub.dev/google/bit/m-r101x1/ilsvrc2012_classification/1", - "version": "1.0.0", - "min_sdk_version": "2.68.1", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "tensorflow", - "framework_version": "2.3", - "py_version": "py37", - }, - "hosting_artifact_key": "tensorflow-infer/infer-tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1.tar.gz", - "hosting_script_key": "source-directory-tarballs/tensorflow/inference/ic/v1.0.0/sourcedir.tar.gz", - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "hyperparameters": [ { - "name": "epochs", - "type": "int", - "default": 3, - "min": 1, - "max": 1000, + "name": "tree_learner", + "type": "text", + "default": "serial", + "options": ["serial", "feature", "data", "voting"], "scope": "algorithm", }, { - "name": "adam-learning-rate", + "name": 
"feature_fraction_bynode", "type": "float", - "default": 0.05, - "min": 1e-08, + "default": 1.0, + "min": 1e-20, "max": 1, "scope": "algorithm", }, { - "name": "batch-size", - "type": "int", - "default": 4, - "min": 1, - "max": 1024, + "name": "is_unbalance", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + {"name": "max_bin", "type": "int", "default": 255, "min": 2, "scope": "algorithm"}, + {"name": "num_threads", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + {"name": "verbosity", "type": "int", "default": 1, "scope": "algorithm"}, + { + "name": "use_dask", + "type": "text", + "default": "False", + "options": ["True", "False"], "scope": "algorithm", }, { @@ -6465,100 +13210,624 @@ "scope": "container", }, ], - "training_script_key": "source-directory-tarballs/tensorflow/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/lightgbm/transfer_learning/classification/v2.2.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/lightgbm/transfer_learning/classification/prepack/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.1", "training_ecr_specs": { - "framework_version": "2.3", - "framework": "tensorflow", - "py_version": "py37", + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py38", }, - "training_artifact_key": "tensorflow-training/train-tensorflow-ic-bit-" - "m-r101x1-ilsvrc2012-classification-1.tar.gz", + "training_artifact_key": "lightgbm-training/train-lightgbm-classification-model.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", "type": "text", "default": "inference.py", "scope": "container", + "required_for_model_class": True, }, { "name": "SAGEMAKER_SUBMIT_DIRECTORY", "type": "text", "default": "/opt/ml/model/code", "scope": "container", + "required_for_model_class": False, }, { "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", "type": "text", "default": "20", "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, }, { "name": "MODEL_CACHE_ROOT", "type": "text", "default": "/opt/ml/model", "scope": "container", + "required_for_model_class": True, }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container", + "required_for_model_class": True, }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, "scope": "container", + "required_for_model_class": True, }, ], - }, - "mxnet-semseg-fcn-resnet50-ade": { - "model_id": "mxnet-semseg-fcn-resnet50-ade", - "url": "https://cv.gluon.ai/model_zoo/segmentation.html", - "version": "1.0.0", - "min_sdk_version": "2.68.1", + "metrics": [ + { + "Name": "lightgbm-classification:multi-log-loss", + "Regex": "multi_logloss: ([0-9\\.]+)", + } + ], + "default_inference_instance_type": "ml.m5.4xlarge", + "supported_inference_instance_types": [ + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + 
"ml.m4.16xlarge", + ], + "default_training_instance_type": "ml.m5.12xlarge", + "supported_training_instance_types": [ + "ml.m5.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + "ml.m4.16xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["text/csv"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "text/csv", + "default_accept_type": "application/json", + }, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/tabular_multiclass/", + "validation_supported": True, + "fine_tuning_supported": False, + "resource_name_base": "lgb-classification-model", + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": 
"763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:2.0.1-gpu-py310", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-west-2": { + "cpu_ecr_uri_1": 
"763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "sa-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-east-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-west-1": { + "cpu_ecr_uri_1": 
"763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "c6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "c6gn": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "m6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "r6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": 
"626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-northeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-south-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-southeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-southeast-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-training:1.9.0-gpu-py38", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.9.0-gpu-py38", + }, + "eu-central-1": { + "alias_ecr_uri_3": 
"763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-west-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "sa-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-east-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-gov-east-1": { + 
"cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-west-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + }, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, + }, + "catboost-classification-model": { + "model_id": 
"catboost-classification-model", + "url": "https://catboost.ai/", + "version": "2.1.6", + "min_sdk_version": "2.189.0", "training_supported": True, - "incremental_training_supported": False, + "incremental_training_supported": True, "hosting_ecr_specs": { - "framework": "mxnet", - "framework_version": "1.7.0", - "py_version": "py3", + "framework": "pytorch", + "framework_version": "2.0.1", + "py_version": "py310", }, - "hosting_artifact_key": "mxnet-infer/infer-mxnet-semseg-fcn-resnet50-ade.tar.gz", - "hosting_script_key": "source-directory-tarballs/mxnet/inference/semseg/v1.0.0/sourcedir.tar.gz", + "hosting_artifact_key": "catboost-classification/catboost-classification-model/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/catboost/inference/classification/v1.1.2/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "catboost-classification/catboost-classification-model/artifacts/inference-prepack/v1.0.1/", + "hosting_prepacked_artifact_version": "1.0.1", + "hosting_use_script_uri": False, "inference_vulnerable": False, - "inference_dependencies": [], + "inference_dependencies": [ + "catboost==1.2.2", + "graphviz==0.20.1", + "plotly==5.18.0", + "tenacity==8.2.3", + ], "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": ["numpy==1.19.5", "opencv_python==4.0.1.23"], + "training_dependencies": [ + "catboost==1.0.1", + "graphviz==0.17", + "plotly==5.1.0", + "sagemaker_jumpstart_prepack_script_utilities==1.0.0", + "sagemaker_jumpstart_script_utilities==1.0.1", + "tenacity==8.0.1", + ], "training_vulnerabilities": [], "deprecated": False, "hyperparameters": [ { - "name": "epochs", + "name": "iterations", + "type": "int", + "default": 500, + "min": 1, + "max": 100000, + "scope": "algorithm", + }, + { + "name": "early_stopping_rounds", "type": "int", "default": 5, "min": 1, - "max": 1000, + "max": 5000, "scope": "algorithm", }, + {"name": "eval_metric", "type": "text", "default": "Auto", "scope": "algorithm"}, { - "name": "adam-learning-rate", + "name": "learning_rate", "type": "float", - "default": 0.001, - "min": 1e-08, + "default": 0.03, + "min": 1e-20, "max": 1, "scope": "algorithm", }, { - "name": "batch-size", + "name": "depth", "type": "int", - "default": 2, + "default": 6, "min": 1, - "max": 1024, + "max": 16, + "scope": "algorithm", + }, + { + "name": "l2_leaf_reg", + "type": "int", + "default": 3, + "min": 1, + "max": 10000, + "scope": "algorithm", + }, + { + "name": "random_strength", + "type": "float", + "default": 1.0, + "min": 1e-20, + "max": 10, + "scope": "algorithm", + }, + {"name": "max_leaves", "type": "int", "default": 31, "min": 2, "scope": "algorithm"}, + { + "name": "rsm", + "type": "float", + "default": 1, + "min": 1e-20, + "max": 1, + "scope": "algorithm", + }, + { + "name": "sampling_frequency", + "type": "text", + "default": "PerTreeLevel", + "options": ["PerTreeLevel", "PerTree"], + "scope": "algorithm", + }, + { + "name": "min_data_in_leaf", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "bagging_temperature", + "type": "float", + "default": 1, + "min": 0, + "scope": "algorithm", + }, + { + "name": "boosting_type", + "type": "text", + "default": "Auto", + "options": ["Auto", "Ordered", "Plain"], + "scope": "algorithm", + }, + { + "name": "scale_pos_weight", + "type": "float", + "default": 1.0, + "min": 1e-20, + "scope": "algorithm", + }, + {"name": "max_bin", "type": "text", "default": "Auto", "scope": "algorithm"}, + { + "name": "grow_policy", + "type": 
"text", + "default": "SymmetricTree", + "options": ["SymmetricTree", "Depthwise", "Lossguide"], "scope": "algorithm", }, + {"name": "random_seed", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + {"name": "thread_count", "type": "int", "default": -1, "min": -1, "scope": "algorithm"}, + {"name": "verbose", "type": "int", "default": 1, "min": 1, "scope": "algorithm"}, { "name": "sagemaker_submit_directory", "type": "text", @@ -6578,199 +13847,552 @@ "scope": "container", }, ], - "training_script_key": "source-directory-tarballs/mxnet/transfer_learning/semseg/v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/catboost/transfer_learning/classification/v1.2.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/catboost/transfer_learning/classification/prepack/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_version": "1.1.1", "training_ecr_specs": { - "framework_version": "1.7.0", - "framework": "mxnet", - "py_version": "py3", + "framework": "pytorch", + "framework_version": "1.9.0", + "py_version": "py38", }, - "training_artifact_key": "mxnet-training/train-mxnet-semseg-fcn-resnet50-ade.tar.gz", + "training_artifact_key": "catboost-training/train-catboost-classification-model.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", "type": "text", "default": "inference.py", "scope": "container", + "required_for_model_class": True, }, { "name": "SAGEMAKER_SUBMIT_DIRECTORY", "type": "text", "default": "/opt/ml/model/code", "scope": "container", + "required_for_model_class": False, }, { "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", "type": "text", "default": "20", "scope": "container", - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "text", - "default": "1", - "scope": "container", + "required_for_model_class": False, }, { "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", "type": "text", "default": "3600", "scope": "container", + "required_for_model_class": False, }, - ], - }, - "huggingface-spc-bert-base-cased": { - "model_id": "huggingface-spc-bert-base-cased", - "url": "https://huggingface.co/bert-base-cased", - "version": "1.0.0", - "min_sdk_version": "2.68.1", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "huggingface", - "framework_version": "1.7.1", - "py_version": "py36", - "huggingface_transformers_version": "4.6.1", - }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-spc-bert-base-cased.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/spc/v1.0.0/sourcedir.tar.gz", - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "hyperparameters": [ { - "name": "epochs", + "name": "ENDPOINT_SERVER_TIMEOUT", "type": "int", - "default": 3, - "min": 1, - "max": 1000, - "scope": "algorithm", - }, - { - "name": "adam-learning-rate", - "type": "float", - "default": 2e-05, - "min": 1e-08, - "max": 1, - "scope": "algorithm", + "default": 3600, + "scope": "container", + "required_for_model_class": True, }, { - "name": "batch-size", - "type": "int", - "default": 8, - "min": 1, - "max": 1024, - "scope": "algorithm", + "name": 
"MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, }, { - "name": "sagemaker_submit_directory", + "name": "SAGEMAKER_ENV", "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "default": "1", "scope": "container", + "required_for_model_class": True, }, { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + {"Name": "catboost-classification:multi-logloss", "Regex": "multi_logloss: ([0-9\\.]+)"} + ], + "default_inference_instance_type": "ml.m5.4xlarge", + "supported_inference_instance_types": [ + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + "ml.m4.16xlarge", + ], + "default_training_instance_type": "ml.m5.12xlarge", + "supported_training_instance_types": [ + "ml.m5.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + "ml.m4.16xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["text/csv"], + "supported_accept_types": ["application/json", "application/json;verbose"], + "default_content_type": "text/csv", + "default_accept_type": "application/json", + }, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/tabular_multiclass/", + "validation_supported": True, + "fine_tuning_supported": False, + "resource_name_base": "cat-classification-model", + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-northeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": 
"364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-south-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-southeast-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-southeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:2.0.1-gpu-py310", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": 
"763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "sa-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-east-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-east-2": { + "cpu_ecr_uri_1": 
"763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, + "us-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310", + "cpu_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-graviton:2.0.1-cpu-py310-ubuntu20.04-sagemaker", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-gpu-py310", + }, }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "c6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "c6gn": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "m6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": 
{"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "r6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_3"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, }, - ], - "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/spc/v1.0.0/sourcedir.tar.gz", - "training_ecr_specs": { - "framework_version": "1.6.0", - "framework": "huggingface", - "huggingface_transformers_version": "4.4.2", - "py_version": "py36", }, - "training_artifact_key": "huggingface-training/train-huggingface-spc-bert-base-cased.tar.gz", - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "text", - "default": "1", - "scope": "container", + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-northeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-south-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": 
"763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-southeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-southeast-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-training:1.9.0-gpu-py38", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.9.0-gpu-py38", + }, + "eu-central-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-west-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": 
"763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "sa-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-east-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, + "us-west-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38", + }, }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": 
{"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, }, - ], + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "dynamic_container_deployment_supported": False, }, - "lightgbm-classification-model": { - "model_id": "lightgbm-classification-model", - "url": "https://lightgbm.readthedocs.io/en/latest/", - "version": "1.0.0", - "min_sdk_version": "2.68.1", + "xgboost-classification-model": { + "model_id": "xgboost-classification-model", + "url": "https://xgboost.readthedocs.io/en/release_1.7.0/", + "version": "2.1.1", + "min_sdk_version": "2.188.0", "training_supported": True, - "incremental_training_supported": False, + "incremental_training_supported": True, "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.9.0", - "py_version": "py38", + "framework": "xgboost", + "framework_version": "1.7-1", + "py_version": "py3", }, - "hosting_artifact_key": "lightgbm-infer/infer-lightgbm-classification-model.tar.gz", - "hosting_script_key": "source-directory-tarballs/lightgbm/inference/classification/v1.0.0/sourcedir.tar.gz", + "hosting_artifact_key": "xgboost-infer/infer-xgboost-classification-model.tar.gz", + "hosting_script_key": "source-directory-tarballs/xgboost/inference/classification/v1.1.0/sourcedir.tar.gz", + "hosting_use_script_uri": True, "inference_vulnerable": False, - "inference_dependencies": [ - "plotly==5.1.0", - "joblib==1.0.1", - "scikit_learn==1.0.1", - "tenacity==8.0.1", - "lightgbm==3.2.1", - "threadpoolctl==2.2.0", - "graphviz==0.17", - ], + "inference_dependencies": [], "inference_vulnerabilities": [], "training_vulnerable": False, 
"training_dependencies": [ - "tenacity==8.0.1", - "plotly==5.1.0", - "graphviz==0.17", - "glibc==0.6.1", - "lightgbm==3.2.1", + "asn1crypto==1.5.1", + "attrs==23.1.0", + "boto3==1.26.158", + "botocore==1.29.159", + "certifi==2023.5.7", + "cffi==1.15.1", + "charset-normalizer==2.1.1", + "cloudpickle==2.2.1", + "contextlib2==21.6.0", + "cryptography==40.0.2", + "dill==0.3.6", + "filelock==3.12.2", + "google-pasta==0.2.0", + "idna==3.4", + "importlib-metadata==4.13.0", + "importlib-resources==5.12.0", + "jmespath==1.0.1", + "jsonschema==4.17.3", + "multiprocess==0.70.14", + "numpy==1.26.4", + "oscrypto==1.3.0", + "packaging==23.1", + "pandas==2.0.2", + "pathos==0.3.0", + "pkgutil-resolve-name==1.3.10", + "platformdirs==3.8.0", + "pox==0.3.2", + "ppft==1.7.6.6", + "protobuf3-to-dict==0.1.5", + "protobuf==3.20.3", + "pycparser==2.21", + "pycryptodomex==3.12.0", + "pyjwt==2.7.0", + "pyopenssl==23.2.0", + "pyrsistent==0.19.3", + "python-dateutil==2.8.2", + "pytz==2023.3", + "pyyaml==6.0", + "requests==2.31.0", + "s3transfer==0.6.1", + "sagemaker==2.164.0", + "sagemaker_jumpstart_script_utilities==1.0.1", + "sagemaker_jumpstart_snowflake_script_utilities==1.1.0", + "schema==0.7.5", + "six==1.16.0", + "smdebug-rulesconfig==1.0.1", + "snowflake-connector-python==3.12.3", + "tblib==1.7.0", + "typing-extensions==4.6.3", + "tzdata==2023.3", + "urllib3==1.26.16", + "zipp==3.15.0", ], "training_vulnerabilities": [], "deprecated": False, @@ -6780,55 +14402,64 @@ "type": "int", "default": 5000, "min": 1, - "max": 100000, + "max": 700000, + "scope": "algorithm", + }, + { + "name": "early_stopping_rounds", + "type": "int", + "default": 30, + "min": 1, + "max": 5000, "scope": "algorithm", }, - {"name": "early_stopping_rounds", "type": "int", "default": 30, "scope": "algorithm"}, { "name": "learning_rate", "type": "float", - "default": 0.009, + "default": 0.3, "min": 1e-20, + "max": 1, "scope": "algorithm", }, + {"name": "gamma", "type": "float", "default": 0, "min": 0, "scope": "algorithm"}, { - "name": "num_leaves", - "type": "int", - "default": 67, - "min": 2, - "max": 131072, + "name": "min_child_weight", + "type": "float", + "default": 1, + "min": 0, "scope": "algorithm", }, + {"name": "max_depth", "type": "int", "default": 6, "min": 1, "scope": "algorithm"}, { - "name": "feature_fraction", + "name": "subsample", "type": "float", - "default": 0.74, + "default": 1, "min": 1e-20, "max": 1, "scope": "algorithm", }, { - "name": "bagging_fraction", + "name": "colsample_bytree", "type": "float", - "default": 0.53, + "default": 1, "min": 1e-20, "max": 1, "scope": "algorithm", }, { - "name": "bagging_freq", - "type": "int", - "default": 5, + "name": "reg_lambda", + "type": "float", + "default": 1, "min": 0, - "max": 100000, + "max": 200, "scope": "algorithm", }, - {"name": "max_depth", "type": "int", "default": 11, "scope": "algorithm"}, { - "name": "min_data_in_leaf", - "type": "int", - "default": 26, + "name": "reg_alpha", + "type": "float", + "default": 0, "min": 0, + "max": 200, "scope": "algorithm", }, { @@ -6850,292 +14481,475 @@ "scope": "container", }, ], - "training_script_key": "source-directory-tarballs/lightgbm/transfer_learning/classification/" - "v1.0.0/sourcedir.tar.gz", - "training_ecr_specs": { - "framework_version": "1.9.0", - "framework": "pytorch", - "py_version": "py38", + "training_script_key": "source-directory-tarballs/training/xgboost-classification/v1.3.1/sourcedir.tar.gz", + "training_ecr_specs": { + "framework": "xgboost", + "framework_version": "1.7-1", + "py_version": "py3", }, - 
"training_artifact_key": "lightgbm-training/train-lightgbm-classification-model.tar.gz", + "training_artifact_key": "xgboost-training/train-xgboost-classification-model.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", "type": "text", "default": "inference.py", "scope": "container", + "required_for_model_class": True, }, { "name": "SAGEMAKER_SUBMIT_DIRECTORY", "type": "text", "default": "/opt/ml/model/code", "scope": "container", + "required_for_model_class": False, }, { "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", "type": "text", "default": "20", "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, }, { "name": "MODEL_CACHE_ROOT", "type": "text", "default": "/opt/ml/model", "scope": "container", + "required_for_model_class": True, }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container", + "required_for_model_class": True, }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, "scope": "container", + "required_for_model_class": True, }, ], - }, - "catboost-classification-model": { - "model_id": "catboost-classification-model", - "url": "https://catboost.ai/", - "version": "1.0.0", - "min_sdk_version": "2.68.1", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.9.0", - "py_version": "py38", - }, - "hosting_artifact_key": "catboost-infer/infer-catboost-classification-model.tar.gz", - "hosting_script_key": "source-directory-tarballs/catboost/inference/classification/v1.0.0/sourcedir.tar.gz", - "inference_vulnerable": False, - "inference_dependencies": [ - "tenacity==8.0.1", - "plotly==5.1.0", - "graphviz==0.17", - "pyparsing==2.4.7", - "cycler==0.10.0", - "kiwisolver==1.3.2", - "matplotlib==3.4.3", - "catboost==1.0.1", - "scikit_learn==1.0.1", - "threadpoolctl==2.2.0", + "metrics": [], + "default_inference_instance_type": "ml.m5.xlarge", + "supported_inference_instance_types": [ + "ml.m5.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + "ml.c4.8xlarge", ], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [ - "tenacity==8.0.1", - "plotly==5.1.0", - "graphviz==0.17", - "catboost==1.0.1", + "default_training_instance_type": "ml.m5.4xlarge", + "supported_training_instance_types": [ + "ml.m5.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + "ml.c4.8xlarge", ], - "training_vulnerabilities": [], - "deprecated": False, - "hyperparameters": [ - { - "name": "iterations", - "type": "int", - "default": 500, - "min": 1, - "max": 100000, - "scope": "algorithm", - }, - { - "name": "early_stopping_rounds", - "type": "int", - "default": 5, - "min": 1, - "max": 5000, - "scope": "algorithm", - }, - { - "name": "learning_rate", - "type": "float", - "default": 
0.03, - "min": 1e-20, - "max": 1, - "scope": "algorithm", - }, - { - "name": "depth", - "type": "int", - "default": 6, - "min": 1, - "max": 16, - "scope": "algorithm", - }, - { - "name": "l2_leaf_reg", - "type": "int", - "default": 3, - "min": 1, - "max": 10000, - "scope": "algorithm", - }, - { - "name": "random_strength", - "type": "float", - "default": 1.0, - "min": 1e-20, - "max": 10, - "scope": "algorithm", + "model_kwargs": {}, + "deploy_kwargs": {}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["text/csv"], + "supported_accept_types": ["application/json", "application/json;verbose"], + "default_content_type": "text/csv", + "default_accept_type": "application/json", + }, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/tabular_multiclass/", + "validation_supported": True, + "fine_tuning_supported": False, + "resource_name_base": "xgb-classification-model", + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "xgb_ecr_uri_1": "510948584623.dkr.ecr.af-south-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-east-1": { + "xgb_ecr_uri_1": "651117190479.dkr.ecr.ap-east-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-northeast-1": { + "xgb_ecr_uri_1": "354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-northeast-2": { + "xgb_ecr_uri_1": "366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-northeast-3": { + "xgb_ecr_uri_1": "867004704886.dkr.ecr.ap-northeast-3.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-south-1": { + "xgb_ecr_uri_1": "720646828776.dkr.ecr.ap-south-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-southeast-1": { + "xgb_ecr_uri_1": "121021644041.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-southeast-2": { + "xgb_ecr_uri_1": "783357654285.dkr.ecr.ap-southeast-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-southeast-3": { + "xgb_ecr_uri_1": "951798379941.dkr.ecr.ap-southeast-3.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ca-central-1": { + "xgb_ecr_uri_1": "341280168497.dkr.ecr.ca-central-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "cn-north-1": { + "xgb_ecr_uri_1": "450853457545.dkr.ecr.cn-north-1.amazonaws.com.cn/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "cn-northwest-1": { + "xgb_ecr_uri_1": "451049120500.dkr.ecr.cn-northwest-1.amazonaws.com.cn/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-central-1": { + "xgb_ecr_uri_1": "492215442770.dkr.ecr.eu-central-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + 
"eu-central-2": { + "xgb_ecr_uri_1": "680994064768.dkr.ecr.eu-central-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-north-1": { + "xgb_ecr_uri_1": "662702820516.dkr.ecr.eu-north-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-south-1": { + "xgb_ecr_uri_1": "978288397137.dkr.ecr.eu-south-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-west-1": { + "xgb_ecr_uri_1": "141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-west-2": { + "xgb_ecr_uri_1": "764974769150.dkr.ecr.eu-west-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-west-3": { + "xgb_ecr_uri_1": "659782779980.dkr.ecr.eu-west-3.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "il-central-1": { + "xgb_ecr_uri_1": "898809789911.dkr.ecr.il-central-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "me-central-1": { + "xgb_ecr_uri_1": "272398656194.dkr.ecr.me-central-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "me-south-1": { + "xgb_ecr_uri_1": "801668240914.dkr.ecr.me-south-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "sa-east-1": { + "xgb_ecr_uri_1": "737474898029.dkr.ecr.sa-east-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-east-1": { + "xgb_ecr_uri_1": "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-east-2": { + "xgb_ecr_uri_1": "257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-gov-east-1": { + "xgb_ecr_uri_1": "237065988967.dkr.ecr.us-gov-east-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-gov-west-1": { + "xgb_ecr_uri_1": "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-west-1": { + "xgb_ecr_uri_1": "746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-west-2": { + "xgb_ecr_uri_1": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, }, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", + "variants": { + "c4": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "g4dn": {"regional_properties": 
{"image_uri": "$xgb_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "inf1": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "inf2": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "xgb_ecr_uri_1": "510948584623.dkr.ecr.af-south-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-east-1": { + "xgb_ecr_uri_1": "651117190479.dkr.ecr.ap-east-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-northeast-1": { + "xgb_ecr_uri_1": "354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-northeast-2": { + "xgb_ecr_uri_1": "366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-northeast-3": { + "xgb_ecr_uri_1": "867004704886.dkr.ecr.ap-northeast-3.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-south-1": { + "xgb_ecr_uri_1": "720646828776.dkr.ecr.ap-south-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-southeast-1": { + "xgb_ecr_uri_1": "121021644041.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-southeast-2": { + "xgb_ecr_uri_1": "783357654285.dkr.ecr.ap-southeast-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ap-southeast-3": { + "xgb_ecr_uri_1": 
"951798379941.dkr.ecr.ap-southeast-3.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "ca-central-1": { + "xgb_ecr_uri_1": "341280168497.dkr.ecr.ca-central-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "cn-north-1": { + "xgb_ecr_uri_1": "450853457545.dkr.ecr.cn-north-1.amazonaws.com.cn/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "cn-northwest-1": { + "xgb_ecr_uri_1": "451049120500.dkr.ecr.cn-northwest-1.amazonaws.com.cn/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-central-1": { + "xgb_ecr_uri_1": "492215442770.dkr.ecr.eu-central-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-central-2": { + "xgb_ecr_uri_1": "680994064768.dkr.ecr.eu-central-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-north-1": { + "xgb_ecr_uri_1": "662702820516.dkr.ecr.eu-north-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-south-1": { + "xgb_ecr_uri_1": "978288397137.dkr.ecr.eu-south-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-west-1": { + "xgb_ecr_uri_1": "141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-west-2": { + "xgb_ecr_uri_1": "764974769150.dkr.ecr.eu-west-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "eu-west-3": { + "xgb_ecr_uri_1": "659782779980.dkr.ecr.eu-west-3.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "il-central-1": { + "xgb_ecr_uri_1": "898809789911.dkr.ecr.il-central-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "me-central-1": { + "xgb_ecr_uri_1": "272398656194.dkr.ecr.me-central-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "me-south-1": { + "xgb_ecr_uri_1": "801668240914.dkr.ecr.me-south-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "sa-east-1": { + "xgb_ecr_uri_1": "737474898029.dkr.ecr.sa-east-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-east-1": { + "xgb_ecr_uri_1": "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-east-2": { + "xgb_ecr_uri_1": "257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-gov-east-1": { + "xgb_ecr_uri_1": "237065988967.dkr.ecr.us-gov-east-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-gov-west-1": { + "xgb_ecr_uri_1": "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-west-1": { + "xgb_ecr_uri_1": 
"746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, + "us-west-2": { + "xgb_ecr_uri_1": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost@sha256:ba417ec6d8d3e0c6b5f463bc9202e3b498b42260a29b61875f34beb6d99d8444" + }, }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", + "variants": { + "c4": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "g6": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "g6e": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "inf1": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "inf2": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$xgb_ecr_uri_1"}}, }, - ], - "training_script_key": "source-directory-tarballs/catboost/transfer_learning/" - "classification/v1.0.0/sourcedir.tar.gz", - "training_ecr_specs": { - "framework_version": "1.9.0", - "framework": "pytorch", - "py_version": "py38", }, - "training_artifact_key": "catboost-training/train-catboost-classification-model.tar.gz", - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": 
"container", - }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "text", - "default": "1", - "scope": "container", - }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - }, - ], + "dynamic_container_deployment_supported": False, }, - "xgboost-classification-model": { - "model_id": "xgboost-classification-model", - "url": "https://xgboost.readthedocs.io/en/latest/", - "version": "1.0.0", - "min_sdk_version": "2.68.1", + "sklearn-classification-linear": { + "model_id": "sklearn-classification-linear", + "url": "https://scikit-learn.org/stable/", + "version": "1.3.1", + "min_sdk_version": "2.188.0", "training_supported": True, "incremental_training_supported": False, "hosting_ecr_specs": { - "framework": "xgboost", - "framework_version": "1.3-1", + "framework": "sklearn", + "framework_version": "1.2-1", "py_version": "py3", }, - "hosting_artifact_key": "xgboost-infer/infer-xgboost-classification-model.tar.gz", - "hosting_script_key": "source-directory-tarballs/xgboost/inference/classification/v1.0.0/sourcedir.tar.gz", + "hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz", + "hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.1.0/sourcedir.tar.gz", + "hosting_use_script_uri": True, "inference_vulnerable": False, "inference_dependencies": [], "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], + "training_dependencies": [ + "asn1crypto==1.5.1", + "attrs==23.1.0", + "boto3==1.26.158", + "botocore==1.29.159", + "certifi==2023.5.7", + "cffi==1.15.1", + "charset-normalizer==2.1.1", + "cloudpickle==2.2.1", + "contextlib2==21.6.0", + "cryptography==40.0.2", + "dill==0.3.6", + "filelock==3.12.2", + "google-pasta==0.2.0", + "idna==3.4", + "importlib-metadata==4.13.0", + "importlib-resources==5.12.0", + "jmespath==1.0.1", + "jsonschema==4.17.3", + "multiprocess==0.70.14", + "numpy==1.24.3", + "oscrypto==1.3.0", + "packaging==23.1", + "pandas==2.0.2", + "pathos==0.3.0", + "pkgutil-resolve-name==1.3.10", + "platformdirs==3.8.0", + "pox==0.3.2", + "ppft==1.7.6.6", + "protobuf3-to-dict==0.1.5", + "protobuf==3.20.3", + "pycparser==2.21", + "pycryptodomex==3.12.0", + "pyjwt==2.7.0", + "pyopenssl==23.2.0", + "pyrsistent==0.19.3", + "python-dateutil==2.8.2", + "pytz==2023.3", + "pyyaml==6.0", + "requests==2.31.0", + "s3transfer==0.6.1", + "sagemaker==2.164.0", + "sagemaker_jumpstart_script_utilities==1.0.1", + "sagemaker_jumpstart_snowflake_script_utilities==1.1.0", + "schema==0.7.5", + "six==1.16.0", + "smdebug-rulesconfig==1.0.1", + "snowflake-connector-python==3.12.3", + "tblib==1.7.0", + "typing-extensions==4.6.3", + "tzdata==2023.3", + "urllib3==1.26.16", + "zipp==3.15.0", + ], "training_vulnerabilities": [], "deprecated": False, "hyperparameters": [ { - "name": "num_boost_round", - "type": "int", - "default": 5000, - "min": 1, - "max": 700000, - "scope": "algorithm", - }, - { - "name": "early_stopping_rounds", - "type": "int", - "default": 30, - "min": 1, - "max": 5000, - "scope": "algorithm", - }, - { - "name": "learning_rate", + "name": "tol", "type": "float", - "default": 0.3, + "default": 0.0001, "min": 1e-20, - "max": 1, - "scope": "algorithm", - }, - {"name": "gamma", "type": "float", "default": 0, "min": 0, "scope": "algorithm"}, - { - "name": "min_child_weight", - "type": "float", - "default": 1, - "min": 0, + "max": 50, "scope": 
"algorithm", }, - {"name": "max_depth", "type": "int", "default": 6, "min": 1, "scope": "algorithm"}, { - "name": "subsample", - "type": "float", - "default": 1, - "min": 1e-20, - "max": 1, + "name": "penalty", + "type": "text", + "default": "l2", + "options": ["l1", "l2", "elasticnet", "none"], "scope": "algorithm", }, { - "name": "colsample_bytree", + "name": "alpha", "type": "float", - "default": 1, + "default": 0.0001, "min": 1e-20, - "max": 1, - "scope": "algorithm", - }, - { - "name": "reg_lambda", - "type": "float", - "default": 1, - "min": 0, - "max": 200, + "max": 999, "scope": "algorithm", }, { - "name": "reg_alpha", + "name": "l1_ratio", "type": "float", - "default": 0, + "default": 0.15, "min": 0, - "max": 200, + "max": 1, "scope": "algorithm", }, { @@ -7157,753 +14971,2707 @@ "scope": "container", }, ], - "training_script_key": "source-directory-tarballs/xgboost/transfer_learning/classification/" - "v1.0.0/sourcedir.tar.gz", + "training_script_key": "source-directory-tarballs/training/sklearn-classification/v2.0.1/sourcedir.tar.gz", "training_ecr_specs": { - "framework_version": "1.3-1", - "framework": "xgboost", + "framework": "sklearn", + "framework_version": "1.2-1", "py_version": "py3", }, - "training_artifact_key": "xgboost-training/train-xgboost-classification-model.tar.gz", + "training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz", "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", "type": "text", "default": "inference.py", "scope": "container", + "required_for_model_class": True, }, { "name": "SAGEMAKER_SUBMIT_DIRECTORY", "type": "text", "default": "/opt/ml/model/code", "scope": "container", + "required_for_model_class": False, }, { "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", "type": "text", "default": "20", "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.m5.xlarge", + "supported_inference_instance_types": [ + "ml.m5.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + "ml.c4.8xlarge", + ], + "default_training_instance_type": "ml.m5.4xlarge", + "supported_training_instance_types": [ + "ml.m5.xlarge", + "ml.m5.2xlarge", + "ml.m5.4xlarge", + "ml.m5.12xlarge", + "ml.m5.24xlarge", + "ml.c5.2xlarge", + "ml.c5.4xlarge", + "ml.c5.9xlarge", + "ml.c5.18xlarge", + "ml.c4.8xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": {}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["text/csv"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "text/csv", + "default_accept_type": 
"application/json", + }, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/tabular_multiclass/", + "validation_supported": True, + "fine_tuning_supported": False, + "resource_name_base": "sklearn-classification-linear", + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "scikit_ecr_uri_1": "510948584623.dkr.ecr.af-south-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" + }, + "ap-east-1": { + "cpu_ecr_uri_2": "651117190479.dkr.ecr.ap-east-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "651117190479.dkr.ecr.ap-east-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-northeast-1": { + "cpu_ecr_uri_2": "354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-northeast-2": { + "cpu_ecr_uri_2": "366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-northeast-3": { + "cpu_ecr_uri_2": "867004704886.dkr.ecr.ap-northeast-3.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "867004704886.dkr.ecr.ap-northeast-3.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-south-1": { + "cpu_ecr_uri_2": "720646828776.dkr.ecr.ap-south-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "720646828776.dkr.ecr.ap-south-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-southeast-1": { + "cpu_ecr_uri_2": "121021644041.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "121021644041.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-southeast-2": { + "cpu_ecr_uri_2": "783357654285.dkr.ecr.ap-southeast-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "783357654285.dkr.ecr.ap-southeast-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-southeast-3": { + "scikit_ecr_uri_1": "951798379941.dkr.ecr.ap-southeast-3.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" + }, + "ca-central-1": { + "cpu_ecr_uri_2": "341280168497.dkr.ecr.ca-central-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "341280168497.dkr.ecr.ca-central-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "cn-north-1": { + "scikit_ecr_uri_1": "450853457545.dkr.ecr.cn-north-1.amazonaws.com.cn/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" + }, + "cn-northwest-1": { + "scikit_ecr_uri_1": 
"451049120500.dkr.ecr.cn-northwest-1.amazonaws.com.cn/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" + }, + "eu-central-1": { + "cpu_ecr_uri_2": "492215442770.dkr.ecr.eu-central-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "492215442770.dkr.ecr.eu-central-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-central-2": { + "cpu_ecr_uri_2": "680994064768.dkr.ecr.eu-central-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "680994064768.dkr.ecr.eu-central-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-north-1": { + "cpu_ecr_uri_2": "662702820516.dkr.ecr.eu-north-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "662702820516.dkr.ecr.eu-north-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-south-1": { + "cpu_ecr_uri_2": "978288397137.dkr.ecr.eu-south-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "978288397137.dkr.ecr.eu-south-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-west-1": { + "cpu_ecr_uri_2": "141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-west-2": { + "cpu_ecr_uri_2": "764974769150.dkr.ecr.eu-west-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "764974769150.dkr.ecr.eu-west-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-west-3": { + "cpu_ecr_uri_2": "659782779980.dkr.ecr.eu-west-3.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "659782779980.dkr.ecr.eu-west-3.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "il-central-1": { + "cpu_ecr_uri_2": "898809789911.dkr.ecr.il-central-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "898809789911.dkr.ecr.il-central-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "me-central-1": { + "cpu_ecr_uri_2": "272398656194.dkr.ecr.me-central-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "272398656194.dkr.ecr.me-central-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "me-south-1": { + "cpu_ecr_uri_2": "801668240914.dkr.ecr.me-south-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "801668240914.dkr.ecr.me-south-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "sa-east-1": { + "cpu_ecr_uri_2": "737474898029.dkr.ecr.sa-east-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "737474898029.dkr.ecr.sa-east-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "us-east-1": { + "cpu_ecr_uri_2": 
"683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "us-east-2": { + "cpu_ecr_uri_2": "257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "us-gov-east-1": { + "cpu_ecr_uri_2": "237065988967.dkr.ecr.us-gov-east-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "237065988967.dkr.ecr.us-gov-east-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "us-gov-west-1": { + "scikit_ecr_uri_1": "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" + }, + "us-west-1": { + "cpu_ecr_uri_2": "746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "us-west-2": { + "cpu_ecr_uri_2": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", + "variants": { + "c4": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "c6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "c6gn": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "c6i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c7g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "c7i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "m6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "m6i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "r6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "r6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "r6i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "r6id": {"regional_properties": 
{"image_uri": "$scikit_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "text", - "default": "1", - "scope": "container", + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "scikit_ecr_uri_1": "510948584623.dkr.ecr.af-south-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" + }, + "ap-east-1": { + "cpu_ecr_uri_2": "651117190479.dkr.ecr.ap-east-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "651117190479.dkr.ecr.ap-east-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-northeast-1": { + "cpu_ecr_uri_2": "354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-northeast-2": { + "cpu_ecr_uri_2": "366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-northeast-3": { + "cpu_ecr_uri_2": "867004704886.dkr.ecr.ap-northeast-3.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "867004704886.dkr.ecr.ap-northeast-3.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-south-1": { + "cpu_ecr_uri_2": "720646828776.dkr.ecr.ap-south-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "720646828776.dkr.ecr.ap-south-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-southeast-1": { + "cpu_ecr_uri_2": "121021644041.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "121021644041.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-southeast-2": { + "cpu_ecr_uri_2": "783357654285.dkr.ecr.ap-southeast-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "783357654285.dkr.ecr.ap-southeast-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "ap-southeast-3": { + "scikit_ecr_uri_1": "951798379941.dkr.ecr.ap-southeast-3.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" + }, + "ca-central-1": { + "cpu_ecr_uri_2": "341280168497.dkr.ecr.ca-central-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "341280168497.dkr.ecr.ca-central-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "cn-north-1": { + "scikit_ecr_uri_1": "450853457545.dkr.ecr.cn-north-1.amazonaws.com.cn/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" + }, + 
"cn-northwest-1": { + "scikit_ecr_uri_1": "451049120500.dkr.ecr.cn-northwest-1.amazonaws.com.cn/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" + }, + "eu-central-1": { + "cpu_ecr_uri_2": "492215442770.dkr.ecr.eu-central-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "492215442770.dkr.ecr.eu-central-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-central-2": { + "cpu_ecr_uri_2": "680994064768.dkr.ecr.eu-central-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "680994064768.dkr.ecr.eu-central-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-north-1": { + "cpu_ecr_uri_2": "662702820516.dkr.ecr.eu-north-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "662702820516.dkr.ecr.eu-north-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-south-1": { + "cpu_ecr_uri_2": "978288397137.dkr.ecr.eu-south-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "978288397137.dkr.ecr.eu-south-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-west-1": { + "cpu_ecr_uri_2": "141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-west-2": { + "cpu_ecr_uri_2": "764974769150.dkr.ecr.eu-west-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "764974769150.dkr.ecr.eu-west-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "eu-west-3": { + "cpu_ecr_uri_2": "659782779980.dkr.ecr.eu-west-3.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "659782779980.dkr.ecr.eu-west-3.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "il-central-1": { + "cpu_ecr_uri_2": "898809789911.dkr.ecr.il-central-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "898809789911.dkr.ecr.il-central-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "me-central-1": { + "cpu_ecr_uri_2": "272398656194.dkr.ecr.me-central-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "272398656194.dkr.ecr.me-central-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "me-south-1": { + "cpu_ecr_uri_2": "801668240914.dkr.ecr.me-south-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "801668240914.dkr.ecr.me-south-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "sa-east-1": { + "cpu_ecr_uri_2": "737474898029.dkr.ecr.sa-east-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "737474898029.dkr.ecr.sa-east-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "us-east-1": { + "cpu_ecr_uri_2": 
"683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "us-east-2": { + "cpu_ecr_uri_2": "257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "us-gov-east-1": { + "cpu_ecr_uri_2": "237065988967.dkr.ecr.us-gov-east-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "237065988967.dkr.ecr.us-gov-east-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "us-gov-west-1": { + "scikit_ecr_uri_1": "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95" + }, + "us-west-1": { + "cpu_ecr_uri_2": "746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, + "us-west-2": { + "cpu_ecr_uri_2": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:1.0-1-arm64-cpu-py3", + "scikit_ecr_uri_1": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn@sha256:e09bbb7686077a1db23d316b699020a786a6e1636b2b89384be9651368c40f95", + }, }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", + "variants": { + "c4": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "c6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "c6gn": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "c6i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "c7g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "c7i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "m6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "m6i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "r6g": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "r6gd": {"regional_properties": {"image_uri": "$cpu_ecr_uri_2"}}, + "r6i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "r6id": {"regional_properties": 
{"image_uri": "$scikit_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$scikit_ecr_uri_1"}}, }, - ], + }, + "dynamic_container_deployment_supported": False, }, - "sklearn-classification-linear": { - "model_id": "sklearn-classification-linear", - "url": "https://scikit-learn.org/stable/", - "version": "1.0.0", - "min_sdk_version": "2.68.1", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "sklearn", - "framework_version": "0.23-1", - "py_version": "py3", +} + +BASE_SPEC = { + "hosting_resource_requirements": {"num_accelerators": 1, "min_memory_mb": 34360}, + "inference_volume_size": 123, + "training_volume_size": 456, + "dynamic_container_deployment_supported": True, + "model_id": "pytorch-ic-mobilenet-v2", + "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", + "version": "3.0.6", + "min_sdk_version": "2.189.0", + "incremental_training_supported": True, + "hosting_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.10.0", + "py_version": "py38", + }, + "hosting_artifact_uri": None, + "hosting_artifact_key": "pytorch-ic/pytorch-ic-mobilenet-v2/artifacts/inference/v2.0.0/", + "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v2.0.0/sourcedir.tar.gz", + "training_supported": True, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.10.0", + "py_version": "py38", + }, + "training_artifact_key": "pytorch-training/v2.0.0/train-pytorch-ic-mobilenet-v2.tar.gz", + "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v2.3.0/sourcedir.tar.gz", + "hyperparameters": [ + { + "name": "train_only_top_layer", + "type": "text", + "options": ["True", "False"], + "default": "True", + "scope": "algorithm", }, - "hosting_artifact_key": "sklearn-infer/infer-sklearn-classification-linear.tar.gz", - "hosting_script_key": "source-directory-tarballs/sklearn/inference/classification/v1.0.0/sourcedir.tar.gz", - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "hyperparameters": [ - { - "name": "tol", - "type": "float", - "default": 0.0001, - "min": 1e-20, - "max": 50, - "scope": "algorithm", - }, - { - "name": "penalty", - "type": "text", - "default": "l2", - "options": ["l1", "l2", "elasticnet", "none"], - "scope": "algorithm", - }, - { - "name": "alpha", - "type": "float", - "default": 0.0001, - "min": 1e-20, - "max": 999, - "scope": "algorithm", - }, - { - "name": "l1_ratio", - "type": "float", - "default": 0.15, - "min": 0, - "max": 1, - "scope": "algorithm", - }, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", - }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", - }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", + { + "name": "epochs", + "type": "int", + "default": 5, + "scope": "algorithm", + "min": 1, + "max": 1000, + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.001, + "scope": "algorithm", + "min": 1e-08, + "max": 1, + }, + { + "name": "batch_size", + "type": "int", + "default": 4, + "scope": 
"algorithm", + "min": 1, + "max": 1024, + }, + { + "name": "reinitialize_top_layer", + "type": "text", + "options": ["Auto", "True", "False"], + "default": "Auto", + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": ["sagemaker_jumpstart_prepack_script_utilities==1.0.0"], + "training_vulnerabilities": [], + "deprecated": False, + "usage_info_message": None, + "deprecated_message": None, + "deprecate_warn_message": None, + "default_inference_instance_type": "ml.m5.large", + "supported_inference_instance_types": [ + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + "ml.m4.large", + "ml.m4.xlarge", + ], + "default_training_instance_type": "ml.m5.xlarge", + "supported_training_instance_types": ["ml.m5.xlarge", "ml.c5.2xlarge", "ml.m4.xlarge"], + "metrics": [{"Name": "pytorch-ic:val-accuracy", "Regex": "val_accuracy: ([0-9\\.]+)"}], + "training_prepacked_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/prepack/v1.1.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "pytorch-ic/pytorch-ic-mobilenet-v2/artifacts/inference-prepack/v1.0.0/", + "model_kwargs": {}, + "deploy_kwargs": {}, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "default_content_type": "application/x-image", + "supported_content_types": ["application/x-image"], + "default_accept_type": "application/json", + "supported_accept_types": ["application/json;verbose", "application/json"], + }, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_uri": None, + "default_training_dataset_key": "training-datasets/tf_flowers/", + "resource_name_base": "pt-ic-mobilenet-v2", + "hosting_eula_key": None, + "hosting_model_package_arns": {}, + 
"training_model_package_artifact_uris": None, + "hosting_use_script_uri": False, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-northeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-south-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:1.10.0-gpu-py38", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": 
"727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-central-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-west-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "sa-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-east-2": { + "alias_ecr_uri_3": 
"763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", + }, + "us-west-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", }, - ], - "training_script_key": "source-directory-tarballs/sklearn/transfer_learning/classification/" - "v1.0.0/sourcedir.tar.gz", - "training_ecr_specs": { - "framework_version": "0.23-1", - "framework": "sklearn", - "py_version": "py3", }, - "training_artifact_key": "sklearn-training/train-sklearn-classification-linear.tar.gz", - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", + "aliases": None, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": 
"$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "cpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "626614931356.dkr.ecr.af-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-east-1": { + "cpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-northeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-northeast-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-northeast-3": { + "cpu_ecr_uri_1": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "364406365360.dkr.ecr.ap-northeast-3.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-south-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-3": { + "cpu_ecr_uri_1": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "907027046896.dkr.ecr.ap-southeast-3.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ap-southeast-5": { + "cpu_ecr_uri_1": "550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": 
"550225433462.dkr.ecr.ap-southeast-5.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "ca-central-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "cn-north-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-training:1.10.0-gpu-py38", + }, + "cn-northwest-1": { + "cpu_ecr_uri_1": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "727897471807.dkr.ecr.cn-northwest-1.amazonaws.com.cn/pytorch-training:1.10.0-gpu-py38", + }, + "eu-central-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-central-2": { + "cpu_ecr_uri_1": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "380420809688.dkr.ecr.eu-central-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-north-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-south-1": { + "cpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-west-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-west-2": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "eu-west-3": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "il-central-1": { + "cpu_ecr_uri_1": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "780543022126.dkr.ecr.il-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "me-central-1": { + "cpu_ecr_uri_1": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "914824155844.dkr.ecr.me-central-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "me-south-1": { + "cpu_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "217643126080.dkr.ecr.me-south-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "sa-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": 
"763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-east-1": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-east-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-gov-east-1": { + "cpu_ecr_uri_1": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "446045086412.dkr.ecr.us-gov-east-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-gov-west-1": { + "cpu_ecr_uri_1": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-west-1": { + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-1.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + }, + "us-west-2": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", + }, + "aliases": None, + "variants": { + "c4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g6e": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"regional_properties": 
{"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r6id": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"regional_properties": {"image_uri": "$alias_ecr_uri_3"}}, + }, + }, + "default_payloads": None, + "gated_bucket": False, + "model_subscription_link": None, + "hosting_additional_data_sources": None, + "hosting_neuron_model_id": None, + "hosting_neuron_model_version": None, + "inference_configs": None, + "inference_config_components": None, + "inference_config_rankings": None, + "training_configs": None, + "training_config_components": None, + "training_config_rankings": None, +} +BASE_HOSTING_ADDITIONAL_DATA_SOURCES = { + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "speculative_decoding_channel", + "artifact_version": "version", + "s3_data_source": { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://bucket/path1", + "hub_access_config": None, + "model_access_config": None, + }, + } + ], + "scripts": [ + { + "channel_name": "scripts_channel", + "artifact_version": "version", + "s3_data_source": { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://bucket/path1", + "hub_access_config": None, + "model_access_config": None, + }, + } + ], + }, +} + +BASE_HEADER = { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v1.0.0.json", +} + +BASE_MANIFEST = [ + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet" + "-inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "1.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-" + "imagenet-inception-v3-classification-4/specs_v1.0.0.json", + }, + { + "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/pytorch-ic-imagenet-" + "inception-v3-classification-4/specs_v2.0.0.json", + }, + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "3.0.0", + "min_version": "4.49.0", + "spec_key": "community_models_specs/tensorflow-ic-" + "imagenet-inception-v3-classification-4/specs_v3.0.0.json", + }, + { + "model_id": "meta-textgeneration-llama-2-7b", + "version": "4.9.0", + "min_version": "2.49.0", + "spec_key": 
"community_models/meta-textgeneration-llama-2-7b/specs_v4.9.0.json", + }, + { + "model_id": "meta-textgeneration-llama-2-7b", + "version": "4.13.0", + "min_version": "2.49.0", + "spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json", + }, +] + +BASE_PROPRIETARY_HEADER = { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], +} + +BASE_PROPRIETARY_MANIFEST = [ + { + "model_id": "ai21-summarization", + "version": "1.1.003", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "lighton-mini-instruct40b", + "version": "v1.0", + "min_version": "2.0.0", + "spec_key": "proprietary-models/lighton-mini-instruct40b/proprietary_specs_v1.0.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "ai21-paraphrase", + "version": "1.0.005", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "ai21-paraphrase", + "version": "v1.00-rc2-not-valid-version", + "min_version": "2.0.0", + "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, + { + "model_id": "nc-soft-model-1", + "version": "v3.0-not-valid-version!", + "min_version": "2.0.0", + "spec_key": "proprietary-models/nc-soft-model-1/proprietary_specs_1.0.005.json", + "search_keywords": ["Text2Text", "Generation"], + }, +] + +BASE_PROPRIETARY_SPEC = { + "model_id": "ai21-jurassic-2-light", + "version": "2.0.004", + "min_sdk_version": "2.999.0", + "listing_id": "prodview-roz6zicyvi666", + "product_id": "1bd680a0-f29b-479d-91c3-9899743021cf", + "model_subscription_link": "https://aws.amazon.com/marketplace/ai/procurement?productId=1bd680a0", + "hosting_notebook_key": "pmm-notebooks/pmm-notebook-ai21-jurassic-2-light.ipynb", + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 600, + }, + "default_payloads": { + "Shakespeare": { + "content_type": "application/json", + "prompt_key": "prompt", + "output_keys": {"generated_text": "[0].completions[0].data.text"}, + "body": {"prompt": "To be, or", "maxTokens": 1, "temperature": 0}, + } + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "hosting_model_package_arns": { + "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-light-v2-0-004", + "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/j2-light-v2-0-004", + "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/j2-light-v2-0-004", + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/j2-light-v2-0-004", + "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/j2-light-v2-0-004", + "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/j2-light-v2-0-004", + "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/j2-light-v2-0-004", + "eu-west-2": 
"arn:aws:sagemaker:eu-west-2:856760150666:model-package/j2-light-v2-0-004", + "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/j2-light-v2-0-004", + "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/j2-light-v2-0-004", + "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/j2-light-v2-0-004", + "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/j2-light-v2-0-004", + "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/j2-light-v2-0-004", + "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/j2-light-v2-0-004", + "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/j2-light-v2-0-004", + "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/j2-light-v2-0-004", + }, +} + + +INFERENCE_CONFIGS = { + "inference_configs": { + "neuron-inference": { + "benchmark_metrics": { + "ml.inf2.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", + "component_names": ["neuron-inference"], + }, + "neuron-inference-budget": { + "benchmark_metrics": { + "ml.inf2.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", + "component_names": ["neuron-base"], + }, + "gpu-inference-budget": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - {"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"}, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "text", - "default": "1", - "scope": "container", + "component_names": ["gpu-inference-budget"], + }, + "gpu-inference": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", + "component_names": ["gpu-inference"], + }, + "gpu-inference-model-package": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] }, - ], - }, -} - -BASE_SPEC = { - "model_id": "pytorch-ic-mobilenet-v2", - "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", - "version": "1.0.0", - "min_sdk_version": "2.49.0", - "training_supported": True, - "incremental_training_supported": True, - "gated_bucket": False, - "default_payloads": None, - "hosting_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", - }, - "hosting_instance_type_variants": None, - "training_ecr_specs": { - "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "component_names": ["gpu-inference-model-package"], + }, + "gpu-accelerated": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ] + }, + "component_names": ["gpu-accelerated"], + }, }, - "training_instance_type_variants": None, - "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", - "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", - 
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", - "training_prepacked_script_key": None, - "hosting_prepacked_artifact_key": None, - "training_model_package_artifact_uris": None, - "deprecate_warn_message": None, - "deprecated_message": None, - "hosting_model_package_arns": None, - "hosting_eula_key": None, - "model_subscription_link": None, - "hyperparameters": [ - { - "name": "epochs", - "type": "int", - "default": 3, - "min": 1, - "max": 1000, - "scope": "algorithm", + "inference_config_components": { + "neuron-base": { + "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"] }, - { - "name": "adam-learning-rate", - "type": "float", - "default": 0.05, - "min": 1e-08, - "max": 1, - "scope": "algorithm", + "neuron-inference": { + "default_inference_instance_type": "ml.inf2.xlarge", + "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], + "hosting_ecr_specs": { + "framework": "huggingface-llm-neuronx", + "framework_version": "0.0.17", + "py_version": "py310", + }, + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, }, - { - "name": "batch-size", - "type": "int", - "default": 4, - "min": 1, - "max": 1024, - "scope": "algorithm", + "neuron-budget": { + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + } + ], }, - { - "name": "sagemaker_submit_directory", - "type": "text", - "default": "/opt/ml/input/data/code/sourcedir.tar.gz", - "scope": "container", + "gpu-inference": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + } + }, + "variants": { + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, }, - { - "name": "sagemaker_program", - "type": "text", - "default": "transfer_learning.py", - "scope": "container", + "gpu-inference-model-package": { + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" + "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + }, }, - { - "name": "sagemaker_container_log_level", - "type": "text", - "default": "20", - "scope": "container", + "gpu-inference-budget": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting:1.13.1-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": 
{"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, }, - ], - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, + "gpu-accelerated": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "draft_model_name", + "artifact_version": "1.2.1", + "s3_data_source": { + "compression_type": "None", + "model_access_config": {"accept_eula": False}, + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/draft/model/artifact/", + }, + } + ], + }, }, - { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", - "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + }, +} + +TRAINING_CONFIGS = { + "training_configs": { + "neuron-training": { + "benchmark_metrics": { + "ml.tr1n1.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ], + "ml.tr1n1.4xlarge": [ + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ], + }, + "component_names": ["neuron-training"], + "default_inference_config": "neuron-inference", + "default_incremental_training_config": "neuron-training", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, - { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", - "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "neuron-training-budget": { + "benchmark_metrics": { + "ml.tr1n1.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ], + "ml.tr1n1.4xlarge": [ + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ], + }, + "component_names": ["neuron-training-budget"], + "default_inference_config": "neuron-inference-budget", + "default_incremental_training_config": "neuron-training-budget", + "supported_inference_configs": ["neuron-inference", "neuron-inference-budget"], + "supported_incremental_training_configs": ["neuron-training", "neuron-training-budget"], }, - { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", - "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, + "gpu-training": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "200", "unit": "Tokens/S", "concurrency": "1"} + ], + }, + "component_names": ["gpu-training"], + "default_inference_config": "gpu-inference", + "default_incremental_training_config": "gpu-training", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, - { - "name": "ENDPOINT_SERVER_TIMEOUT", - "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, + "gpu-training-budget": { + "benchmark_metrics": { + "ml.p3.2xlarge": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S", 
"concurrency": "1"} + ] + }, + "component_names": ["gpu-training-budget"], + "default_inference_config": "gpu-inference-budget", + "default_incremental_training_config": "gpu-training-budget", + "supported_inference_configs": ["gpu-inference", "gpu-inference-budget"], + "supported_incremental_training_configs": ["gpu-training", "gpu-training-budget"], }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + }, + "training_config_components": { + "neuron-training": { + "default_training_instance_type": "ml.trn1.2xlarge", + "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, }, - { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "gpu-training": { + "default_training_instance_type": "ml.p2.xlarge", + "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-training:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + }, + "neuron-training-budget": { + "default_training_instance_type": "ml.trn1.2xlarge", + "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, }, - { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, - "scope": "container", - "required_for_model_class": True, + "gpu-training-budget": { + "default_training_instance_type": "ml.p2.xlarge", + "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training:1.13.1-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, }, - ], - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - 
"training_vulnerabilities": [], - "deprecated": False, - "default_inference_instance_type": "ml.p2.xlarge", - "supported_inference_instance_types": [ - "ml.p2.xlarge", - "ml.p3.2xlarge", - "ml.g4dn.xlarge", - "ml.m5.large", - "ml.m5.xlarge", - "ml.c5.xlarge", - "ml.c5.2xlarge", - ], - "default_training_instance_type": "ml.p3.2xlarge", - "supported_training_instance_types": [ - "ml.p3.2xlarge", - "ml.p2.xlarge", - "ml.g4dn.2xlarge", - "ml.m5.xlarge", - "ml.c5.2xlarge", - ], - "hosting_use_script_uri": True, - "usage_info_message": None, - "metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], - "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, - "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, - "estimator_kwargs": { - "encrypt_inter_container_traffic": True, }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, - "predictor_specs": { - "supported_content_types": ["application/x-image"], - "supported_accept_types": ["application/json;verbose", "application/json"], - "default_content_type": "application/x-image", - "default_accept_type": "application/json", - }, - "inference_volume_size": 123, - "training_volume_size": 456, - "inference_enable_network_isolation": True, - "training_enable_network_isolation": False, - "resource_name_base": "dfsdfsds", - "hosting_resource_requirements": {"num_accelerators": 1, "min_memory_mb": 34360}, - "dynamic_container_deployment_supported": True, - "inference_configs": None, - "inference_config_components": None, - "training_configs": None, - "training_config_components": None, - "inference_config_rankings": None, - "training_config_rankings": None, -} - -BASE_HEADER = { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "1.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet" - "-inception-v3-classification-4/specs_v1.0.0.json", } -BASE_MANIFEST = [ - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "1.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet" - "-inception-v3-classification-4/specs_v1.0.0.json", - }, - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "2.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/tensorflow-ic-imagenet" - "-inception-v3-classification-4/specs_v2.0.0.json", - }, - { - "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", - "version": "1.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-" - "imagenet-inception-v3-classification-4/specs_v1.0.0.json", - }, - { - "model_id": "pytorch-ic-imagenet-inception-v3-classification-4", - "version": "2.0.0", - "min_version": "2.49.0", - "spec_key": "community_models_specs/pytorch-ic-imagenet-" - "inception-v3-classification-4/specs_v2.0.0.json", - }, - { - "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", - "version": "3.0.0", - "min_version": "4.49.0", - "spec_key": "community_models_specs/tensorflow-ic-" - "imagenet-inception-v3-classification-4/specs_v3.0.0.json", - }, -] -BASE_PROPRIETARY_HEADER = { - "model_id": "ai21-summarization", - "version": "1.1.003", - "min_version": "2.0.0", - "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", - "search_keywords": ["Text2Text", "Generation"], +INFERENCE_CONFIG_RANKINGS = { + "inference_config_rankings": { + "overall": { + "description": 
"Overall rankings of configs", + "rankings": [ + "neuron-inference", + "neuron-inference-budget", + "gpu-inference", + "gpu-inference-budget", + "gpu-accelerated", + ], + }, + "performance": { + "description": "Configs ranked based on performance", + "rankings": [ + "neuron-inference", + "gpu-inference", + "neuron-inference-budget", + "gpu-inference-budget", + ], + }, + "cost": { + "description": "Configs ranked based on cost", + "rankings": [ + "neuron-inference-budget", + "gpu-inference-budget", + "neuron-inference", + "gpu-inference", + ], + }, + } } -BASE_PROPRIETARY_MANIFEST = [ - { - "model_id": "ai21-summarization", - "version": "1.1.003", - "min_version": "2.0.0", - "spec_key": "proprietary-models/ai21-summarization/proprietary_specs_1.1.003.json", - "search_keywords": ["Text2Text", "Generation"], - }, - { - "model_id": "lighton-mini-instruct40b", - "version": "v1.0", - "min_version": "2.0.0", - "spec_key": "proprietary-models/lighton-mini-instruct40b/proprietary_specs_v1.0.json", - "search_keywords": ["Text2Text", "Generation"], - }, - { - "model_id": "ai21-paraphrase", - "version": "1.0.005", - "min_version": "2.0.0", - "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", - "search_keywords": ["Text2Text", "Generation"], - }, - { - "model_id": "ai21-paraphrase", - "version": "v1.00-rc2-not-valid-version", - "min_version": "2.0.0", - "spec_key": "proprietary-models/ai21-paraphrase/proprietary_specs_1.0.005.json", - "search_keywords": ["Text2Text", "Generation"], - }, - { - "model_id": "nc-soft-model-1", - "version": "v3.0-not-valid-version!", - "min_version": "2.0.0", - "spec_key": "proprietary-models/nc-soft-model-1/proprietary_specs_1.0.005.json", - "search_keywords": ["Text2Text", "Generation"], - }, -] - -BASE_PROPRIETARY_SPEC = { - "model_id": "ai21-jurassic-2-light", - "version": "2.0.004", - "min_sdk_version": "2.999.0", - "listing_id": "prodview-roz6zicyvi666", - "product_id": "1bd680a0-f29b-479d-91c3-9899743021cf", - "model_subscription_link": "https://aws.amazon.com/marketplace/ai/procurement?productId=1bd680a0", - "hosting_notebook_key": "pmm-notebooks/pmm-notebook-ai21-jurassic-2-light.ipynb", - "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 600, - }, - "default_payloads": { - "Shakespeare": { - "content_type": "application/json", - "prompt_key": "prompt", - "output_keys": {"generated_text": "[0].completions[0].data.text"}, - "body": {"prompt": "To be, or", "maxTokens": 1, "temperature": 0}, - } - }, - "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", - "default_accept_type": "application/json", - }, - "default_inference_instance_type": "ml.p4de.24xlarge", - "supported_inference_instance_types": ["ml.p4de.24xlarge"], - "hosting_model_package_arns": { - "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-light-v2-0-004", - "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/j2-light-v2-0-004", - "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/j2-light-v2-0-004", - "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/j2-light-v2-0-004", - "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/j2-light-v2-0-004", - "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/j2-light-v2-0-004", - "eu-west-1": 
"arn:aws:sagemaker:eu-west-1:985815980388:model-package/j2-light-v2-0-004", - "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/j2-light-v2-0-004", - "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/j2-light-v2-0-004", - "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/j2-light-v2-0-004", - "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/j2-light-v2-0-004", - "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/j2-light-v2-0-004", - "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/j2-light-v2-0-004", - "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/j2-light-v2-0-004", - "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/j2-light-v2-0-004", - "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/j2-light-v2-0-004", - }, +TRAINING_CONFIG_RANKINGS = { + "training_config_rankings": { + "overall": { + "description": "Overall rankings of configs", + "rankings": [ + "neuron-training", + "neuron-training-budget", + "gpu-training", + "gpu-training-budget", + ], + }, + "performance_training": { + "description": "Configs ranked based on performance", + "rankings": [ + "neuron-training", + "gpu-training", + "neuron-training-budget", + "gpu-training-budget", + ], + "instance_type_overrides": { + "ml.p2.xlarge": [ + "neuron-training", + "neuron-training-budget", + "gpu-training", + "gpu-training-budget", + ] + }, + }, + "cost_training": { + "description": "Configs ranked based on cost", + "rankings": [ + "neuron-training-budget", + "gpu-training-budget", + "neuron-training", + "gpu-training", + ], + }, + } } -INFERENCE_CONFIGS = { - "inference_configs": { - "neuron-inference": { - "benchmark_metrics": { - "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] +DEPLOYMENT_CONFIGS = [ + { + "DeploymentConfigName": "neuron-inference", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } }, - "component_names": ["neuron-inference"], + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, }, - "neuron-inference-budget": { - "benchmark_metrics": { - "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], + }, + { + "DeploymentConfigName": "neuron-inference-budget", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": 
"s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } }, - "component_names": ["neuron-base"], + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, }, - "gpu-inference-budget": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], + }, + { + "DeploymentConfigName": "gpu-inference-budget", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } }, - "component_names": ["gpu-inference-budget"], + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, }, - "gpu-inference": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + "AccelerationConfigs": None, + "BenchmarkMetrics": [ + {"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs", "concurrency": 1} + ], + }, + { + "DeploymentConfigName": "gpu-inference", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface" + "-textgeneration-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } }, - "component_names": ["gpu-inference"], + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.p2.xlarge", + "ComputeResourceRequirements": {"MinMemoryRequiredInMb": None}, + "ModelDataDownloadTimeout": None, + "ContainerStartupHealthCheckTimeout": None, }, + "AccelerationConfigs": None, + "BenchmarkMetrics": [{"name": "Instance Rate", "value": "0.0083000000", "unit": "USD/Hrs"}], }, - "inference_config_components": { - "neuron-base": { - 
"supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"] +] + + +INIT_KWARGS = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu" + "-py310-cu121-ubuntu20.04", + "model_data": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-alpha-us-west-2/huggingface-textgeneration/huggingface-textgeneration" + "-bloom-1b1/artifacts/inference-prepack/v4.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "instance_type": "ml.p2.xlarge", + "env": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "role": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "name": "hf-textgeneration-bloom-1b1-2024-04-22-20-23-48-799", + "enable_network_isolation": True, +} + +HUB_MODEL_DOCUMENT_DICTS = { + "huggingface-llm-gemma-2b-instruct": { + "Url": "https://huggingface.co/google/gemma-2b-it", + "MinSdkVersion": "2.189.0", + "TrainingSupported": True, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04", # noqa: E501 + "HostingArtifactS3DataType": "S3Prefix", + "HostingArtifactCompressionType": "None", + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference-prepack/v1.0.0/", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.0", + "HostingUseScriptUri": False, + "HostingEulaUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/fmhMetadata/terms/gemmaTerms.txt", + "TrainingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/llm/prepack/v1.1.1/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptVersion": "1.1.1", + "TrainingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", # noqa: E501 + "TrainingArtifactS3DataType": "S3Prefix", + "TrainingArtifactCompressionType": "None", + "TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "ModelTypes": ["OPEN_WEIGHTS", "PROPRIETARY"], + "Hyperparameters": [ + { + "Name": "peft_type", + "Type": "text", + "Default": "lora", + "Options": ["lora", "None"], + "Scope": "algorithm", + }, + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "chat_dataset", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "epoch", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + 
"Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "lora_r", + "Type": "int", + "Default": 64, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + {"Name": "lora_alpha", "Type": "int", "Default": 16, "Min": 0, "Scope": "algorithm"}, + { + "Name": "lora_dropout", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + {"Name": "bits", "Type": "int", "Default": 4, "Scope": "algorithm"}, + { + "Name": "double_quant", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "quant_Type", + "Type": "text", + "Default": "nf4", + "Options": ["fp4", "nf4"], + "Scope": "algorithm", + }, + { + "Name": "per_device_train_batch_size", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "per_device_eval_batch_size", + "Type": "int", + "Default": 2, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "warmup_ratio", + "Type": "float", + "Default": 0.1, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "train_from_scratch", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "fp16", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "bf16", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "evaluation_strategy", + "Type": "text", + "Default": "steps", + "Options": ["steps", "epoch", "no"], + "Scope": "algorithm", + }, + { + "Name": "eval_steps", + "Type": "int", + "Default": 20, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "gradient_accumulation_steps", + "Type": "int", + "Default": 4, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "logging_steps", + "Type": "int", + "Default": 8, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "weight_decay", + "Type": "float", + "Default": 0.2, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "load_best_model_at_end", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "max_train_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + { + "Name": "max_val_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + { + "Name": "seed", + "Type": "int", + "Default": 10, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "max_input_length", + "Type": "int", + "Default": 1024, + "Min": -1, + "Scope": "algorithm", + }, + { + "Name": "validation_split_ratio", + "Type": "float", + "Default": 0.2, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "train_data_split_seed", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "preprocessing_num_workers", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + {"Name": "max_steps", "Type": "int", "Default": -1, "Scope": "algorithm"}, + { + "Name": "gradient_checkpointing", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "early_stopping_patience", + "Type": "int", + "Default": 3, + "Min": 1, + "Scope": "algorithm", + }, + { + "Name": "early_stopping_threshold", + "Type": "float", + "Default": 0.0, + "Min": 0, + "Scope": 
"algorithm", + }, + { + "Name": "adam_beta1", + "Type": "float", + "Default": 0.9, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "adam_beta2", + "Type": "float", + "Default": 0.999, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "adam_epsilon", + "Type": "float", + "Default": 1e-08, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "max_grad_norm", + "Type": "float", + "Default": 1.0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "label_smoothing_factor", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "logging_first_step", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "logging_nan_inf_filter", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "save_strategy", + "Type": "text", + "Default": "steps", + "Options": ["no", "epoch", "steps"], + "Scope": "algorithm", + }, + { + "Name": "save_steps", + "Type": "int", + "Default": 500, + "Min": 1, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "save_total_limit", + "Type": "int", + "Default": 1, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "dataloader_drop_last", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "dataloader_num_workers", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "eval_accumulation_steps", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + { + "Name": "auto_find_batch_size", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "lr_scheduler_type", + "Type": "text", + "Default": "constant_with_warmup", + "Options": ["constant_with_warmup", "linear"], + "Scope": "algorithm", + }, + { + "Name": "warmup_steps", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "deepspeed", + "Type": "text", + "Default": "False", + "Options": ["False"], + "Scope": "algorithm", + }, + { + "Name": "sagemaker_submit_directory", + "Type": "text", + "Default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "Scope": "container", + }, + { + "Name": "sagemaker_program", + "Type": "text", + "Default": "transfer_learning.py", + "Scope": "container", + }, + { + "Name": "sagemaker_container_log_level", + "Type": "text", + "Default": "20", + "Scope": "container", + }, + ], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_SUBMIT_DIRECTORY", + "Type": "text", + "Default": "/opt/ml/model/code", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "Type": "text", + "Default": "20", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "Type": "text", + "Default": "3600", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "ENDPOINT_SERVER_TIMEOUT", + "Type": "int", + "Default": 3600, + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MODEL_CACHE_ROOT", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_ENV", + "Type": "text", + "Default": "1", + "Scope": 
"container", + "RequiredForModelClass": True, + }, + { + "Name": "HF_MODEL_ID", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MAX_INPUT_LENGTH", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MAX_TOTAL_TOKENS", + "Type": "text", + "Default": "8192", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MAX_BATCH_PREFILL_TOKENS", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SM_NUM_GPUS", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "Type": "int", + "Default": 1, + "Scope": "container", + "RequiredForModelClass": True, + }, + ], + "TrainingMetrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss': ([0-9]+\\.[0-9]+)", + }, # noqa: E501 + ], + "InferenceDependencies": [], + "TrainingDependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], + "DefaultInferenceInstanceType": "ml.g5.xlarge", + "SupportedInferenceInstanceTypes": [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "DefaultTrainingInstanceType": "ml.g5.2xlarge", + "SupportedTrainingInstanceTypes": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", }, - "neuron-inference": { - "default_inference_instance_type": "ml.inf2.xlarge", - "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], - "hosting_ecr_specs": { - "framework": "huggingface-llm-neuronx", - "framework_version": "0.0.17", - "py_version": "py310", + "InferenceVolumeSize": 512, + "TrainingVolumeSize": 512, + "InferenceEnableNetworkIsolation": True, + "TrainingEnableNetworkIsolation": True, + "FineTuningSupported": True, + "ValidationSupported": True, + "DefaultTrainingDatasetUri": "s3://jumpstart-cache-prod-us-west-2/training-datasets/oasst_top/train/", # noqa: E501 + "ResourceNameBase": "hf-llm-gemma-2b-instruct", + "DefaultPayloads": { + "HelloWorld": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "user\nWrite a hello world program\nmodel", # noqa: E501 + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, + }, }, - "hosting_artifact_key": 
"artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - } + "MachineLearningPoem": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "Write me a poem about Machine Learning.", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, }, - "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, }, }, - "neuron-budget": {"inference_environment_variables": {"BUDGET": "1234"}}, - "gpu-inference": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "GatedBucket": True, + "HostingResourceRequirements": {"MinMemoryMb": 8192, "NumAccelerators": 1}, + "HostingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "TrainingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g4dn/v1.0.0/", # noqa: E501 + }, }, - "variants": { - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g5/v1.0.0/", # noqa: E501 + }, + }, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p3dn/v1.0.0/", # noqa: E501 + }, 
+ }, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p4d/v1.0.0/", # noqa: E501 + }, + }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "InferenceConfigRankings": { + "overall": {"Description": "default", "Rankings": ["variant1"]} + }, + "InferenceConfigs": { + "variant1": { + "ComponentNames": ["variant1"], + "BenchmarkMetrics": { + "ml.g5.12xlarge": [ + {"Name": "latency", "Unit": "sec", "Value": "0.19", "Concurrency": "1"}, + ] }, }, }, - "gpu-inference-budget": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + "InferenceConfigComponents": { + "variant1": { + "HostingEcrUri": "123456789012.ecr.us-west-2.amazon.com/repository", + "HostingArtifactUri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/variant1/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-monarch-test-hub-bucket/monarch-curated-hub-1714579993.88695/curated_models/meta-textgeneration-llama-2-7b/4.0.0/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", # noqa: E501 + "InferenceDependencies": [], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + } + ], + "HostingAdditionalDataSources": { + "speculative_decoding": [ + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_1", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/1", + }, + }, + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_2", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/2", + }, + }, + ] }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, + "ContextualHelp": { + "HubFormatTrainData": [ + "A train directory and an optional validation directory. Each directory contains a CSV/JSON/TXT file.
", + "- For CSV/JSON files, the text data is used from the column called 'text' or the first column if no column called 'text' is found", # noqa: E501 + "- The number of files under train and validation (if provided) should equal one, respectively.", + " [Learn how to set up an AWS S3 bucket.](https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html)", # noqa: E501 + ], + "HubDefaultTrainData": [ + "Dataset: [SEC](https://www.sec.gov/edgar/searchedgar/companysearch)", + "SEC filings contain regulatory documents that companies and issuers of securities must submit to the Securities and Exchange Commission (SEC) on a regular basis.", # noqa: E501 + "License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode)", + ], + }, + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "EncryptInterContainerTraffic": True, + "DisableOutputCompression": True, + "MaxRuntimeInSeconds": 360000, + "DynamicContainerDeploymentSupported": True, + "TrainingModelPackageArtifactUri": None, + "Dependencies": [], }, -} - -TRAINING_CONFIGS = { - "training_configs": { - "neuron-training": { - "benchmark_metrics": { - "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], - "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + "meta-textgeneration-llama-2-70b": { + "Url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + "MinSdkVersion": "2.198.0", + "TrainingSupported": True, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # noqa: E501 + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-70b/artifacts/inference/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-70b/artifacts/inference-prepack/v1.0.0/", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.0", + "HostingUseScriptUri": False, + "HostingEulaUri": "s3://jumpstart-cache-prod-us-west-2/fmhMetadata/eula/llamaEula.txt", + "InferenceDependencies": [], + "TrainingDependencies": [ + "accelerate==0.21.0", + "bitsandbytes==0.39.1", + "black==23.7.0", + "brotli==1.0.9", + "datasets==2.14.1", + "fire==0.5.0", + "huggingface-hub==0.20.3", + "inflate64==0.3.1", + "loralib==0.1.1", + "multivolumefile==0.2.3", + "mypy-extensions==1.0.0", + "nvidia-cublas-cu12==12.1.3.1", + "nvidia-cuda-cupti-cu12==12.1.105", + "nvidia-cuda-nvrtc-cu12==12.1.105", + "nvidia-cuda-runtime-cu12==12.1.105", + "nvidia-cudnn-cu12==8.9.2.26", + "nvidia-cufft-cu12==11.0.2.54", + "nvidia-curand-cu12==10.3.2.106", + "nvidia-cusolver-cu12==11.4.5.107", + "nvidia-cusparse-cu12==12.1.0.106", + "nvidia-nccl-cu12==2.19.3", + "nvidia-nvjitlink-cu12==12.3.101", + "nvidia-nvtx-cu12==12.1.105", + "pathspec==0.11.1", + "peft==0.4.0", + "py7zr==0.20.5", + "pybcj==1.0.1", + "pycryptodomex==3.18.0", + "pyppmd==1.0.0", + "pyzstd==0.15.9", + "safetensors==0.3.1", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.4", + "sagemaker_jumpstart_script_utilities==1.1.9", + "scipy==1.11.1", + "termcolor==2.3.0", + "texttable==1.6.7", + "tokenize-rt==5.1.0", + "tokenizers==0.13.3", + "torch>=2.6.0", +
"transformers==4.33.3", + "triton==2.2.0", + "typing-extensions==4.8.0", + ], + "Hyperparameters": [ + { + "Name": "epoch", + "Type": "int", + "Default": 5, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", }, - "component_names": ["neuron-training"], - }, - "neuron-training-budget": { - "benchmark_metrics": { - "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], - "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + { + "Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", }, - "component_names": ["neuron-training-budget"], - }, - "gpu-training": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}], + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["gpu-training"], - }, - "gpu-training-budget": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + ], + "TrainingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/transfer_learning/textgeneration/v1.0.11/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.5/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptVersion": "1.0.5", + "TrainingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # TODO: not a training image # noqa: E501 + "TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-training/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + "InferenceEnvironmentVariables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, }, - "component_names": ["gpu-training-budget"], - }, - }, - "training_config_components": { - "neuron-training": { - "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } - }, - "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + ], + "TrainingMetrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", }, + ], + "DefaultInferenceInstanceType": "ml.g5.48xlarge", + "supported_inference_instance_types": ["ml.g5.48xlarge", "ml.p4d.24xlarge"], + "default_training_instance_type": "ml.g5.48xlarge", + "SupportedInferenceInstanceTypes": ["ml.g5.48xlarge", "ml.p4d.24xlarge"], + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "EncryptInterContainerTraffic": True, + "DisableOutputCompression": True, + "MaxRuntimeInSeconds": 360000, + 
"SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", }, - "gpu-training": { - "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "InferenceVolumeSize": 256, + "TrainingVolumeSize": 256, + "InferenceEnableNetworkIsolation": True, + "TrainingEnableNetworkIsolation": True, + "DefaultTrainingDatasetUri": "s3://jumpstart-cache-prod-us-west-2/training-datasets/sec_amazon/", # noqa: E501 + "ValidationSupported": True, + "FineTuningSupported": True, + "ResourceNameBase": "meta-textgeneration-llama-2-70b", + "DefaultPayloads": { + "meaningOfLife": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "Body": { + "inputs": "I believe the meaning of life is", + "parameters": { + "max_new_tokens": 64, + "top_p": 0.9, + "temperature": 0.6, + "decoder_input_details": True, + "details": True, + }, }, }, - }, - "neuron-training-budget": { - "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "theoryOfRelativity": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": {"generated_text": "[0].generated_text"}, + "Body": { + "inputs": "Simply put, the theory of relativity states that ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, }, - "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, }, }, - "gpu-training-budget": { - "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "GatedBucket": True, + "HostingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": 
{"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "TrainingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "meta-training/g5/v1.0.0/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + }, }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "meta-training/p4d/v1.0.0/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + }, }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, + "HostingArtifactS3DataType": "S3Prefix", + "HostingArtifactCompressionType": "None", + "HostingResourceRequirements": {"MinMemoryMb": 393216, "NumAccelerators": 8}, + "DynamicContainerDeploymentSupported": True, + "TrainingModelPackageArtifactUri": None, + "Task": "text generation", + "DataType": "text", + "Framework": "meta", + "Dependencies": [], }, -} - - -INFERENCE_CONFIG_RANKINGS = { - "inference_config_rankings": { - "overall": { - "description": "Overall rankings of configs", - "rankings": [ - "neuron-inference", - "neuron-inference-budget", - "gpu-inference", - "gpu-inference-budget", - ], - }, - "performance": { - "description": "Configs ranked based on performance", - "rankings": [ - "neuron-inference", - "gpu-inference", - "neuron-inference-budget", - "gpu-inference-budget", - ], - }, - "cost": { - "description": "Configs ranked based on cost", - "rankings": [ - "neuron-inference-budget", - "gpu-inference-budget", - "neuron-inference", - "gpu-inference", - ], + "huggingface-textembedding-bloom-7b1": { + "Url": "https://huggingface.co/bigscience/bloom-7b1", + "MinSdkVersion": "2.144.0", + "TrainingSupported": False, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # noqa: E501 + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/infer-huggingface-textembedding-bloom-7b1.tar.gz", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/textembedding/v1.0.1/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/infer-prepack-huggingface-textembedding-bloom-7b1.tar.gz", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.1", + "InferenceDependencies": [ + "accelerate==0.16.0", + "bitsandbytes==0.37.0", + "filelock==3.9.0", + 
"huggingface_hub==0.12.0", + "regex==2022.7.9", + "tokenizers==0.13.2", + "transformers==4.26.0", + ], + "TrainingDependencies": [], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + } + ], + "TrainingMetrics": [], + "DefaultInferenceInstanceType": "ml.g5.12xlarge", + "SupportedInferenceInstanceTypes": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.g4dn.12xlarge", + ], + "deploy_kwargs": { + "ModelDataDownloadTimeout": 3600, + "ContainerStartupHealthCheckTimeout": 3600, }, - } -} - -TRAINING_CONFIG_RANKINGS = { - "training_config_rankings": { - "overall": { - "description": "Overall rankings of configs", - "rankings": [ - "neuron-training", - "neuron-training-budget", - "gpu-training", - "gpu-training-budget", - ], + "SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json", "application/x-text"], + "SupportedAcceptTypes": ["application/json;verbose", "application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", }, - "performance_training": { - "description": "Configs ranked based on performance", - "rankings": [ - "neuron-training", - "gpu-training", - "neuron-training-budget", - "gpu-training-budget", - ], - "instance_type_overrides": { - "ml.p2.xlarge": [ - "neuron-training", - "neuron-training-budget", - "gpu-training", - "gpu-training-budget", - ] + "InferenceVolumeSize": 256, + "InferenceEnableNetworkIsolation": True, + "ValidationSupported": False, + "FineTuningSupported": False, + "ResourceNameBase": "hf-textembedding-bloom-7b1", + "HostingInstanceTypeVariants": { + "Aliases": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", # noqa: E501 + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-gpu-py38", + }, + "Variants": { + "c4": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"properties": 
{"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"properties": {"image_uri": "$alias_ecr_uri_3"}}, }, }, - "cost_training": { - "description": "Configs ranked based on cost", - "rankings": [ - "neuron-training-budget", - "gpu-training-budget", - "neuron-training", - "gpu-training", - ], - }, - } + "TrainingModelPackageArtifactUri": None, + "DynamicContainerDeploymentSupported": False, + "License": "BigScience RAIL", + "Dependencies": [], + }, } diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index ce5f15b287..4a64b413f4 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -46,6 +46,8 @@ from sagemaker.model import Model from sagemaker.predictor import Predictor from tests.unit.sagemaker.jumpstart.utils import ( + get_prototype_manifest, + get_prototype_spec_with_configs, get_special_model_spec, overwrite_dictionary, ) @@ -67,8 +69,12 @@ class EstimatorTest(unittest.TestCase): @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @@ -192,9 +198,14 @@ def test_non_prepacked( ], ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -210,7 +221,10 @@ def test_prepacked( mock_session_estimator: mock.Mock, mock_session_model: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, ): + mock_sagemaker_timestamp.return_value = "8675309" + mock_estimator_deploy.return_value = default_predictor mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -229,32 +243,71 @@ def test_prepacked( mock_estimator_init.assert_called_once_with( instance_type="ml.p3.16xlarge", instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:1.10.2" - "-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", - model_uri="s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface" + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface" + 
"-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + model_uri="s3://jumpstart-cache-prod-us-west-2/huggingface-training" + "/train-huggingface" "-text2text-flan-t5-base.tar.gz", - source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/" - "transfer_learning/text2text/prepack/v1.0.1/sourcedir.tar.gz", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs" + "/huggingface/transfer_learning/text2text/prepack/v2.0.0/sourcedir.tar.gz", entry_point="transfer_learning.py", hyperparameters={ "epochs": "1", + "max_steps": "-1", "seed": "42", "batch_size": "64", "learning_rate": "0.0001", + "lr_scheduler_type": "constant_with_warmup", + "warmup_ratio": "0.0", + "warmup_steps": "0", "validation_split_ratio": "0.05", "train_data_split_seed": "0", + "max_train_samples": "-1", + "max_eval_samples": "-1", + "max_input_length": "-1", + "max_output_length": "128", + "pad_to_max_length": "True", + "gradient_accumulation_steps": "1", + "weight_decay": "0.0", + "adam_beta1": "0.9", + "adam_beta2": "0.999", + "adam_epsilon": "1e-08", + "max_grad_norm": "1.0", + "load_best_model_at_end": "True", + "early_stopping_patience": "3", + "early_stopping_threshold": "0.0", + "label_smoothing_factor": "0", + "logging_strategy": "steps", + "logging_first_step": "False", + "logging_steps": "500", + "logging_nan_inf_filter": "True", + "save_strategy": "epoch", + "save_steps": "500", + "save_total_limit": "2", + "dataloader_drop_last": "False", + "dataloader_num_workers": "0", + "evaluation_strategy": "epoch", + "eval_steps": "500", + "eval_accumulation_steps": "None", + "gradient_checkpointing": "True", + "auto_find_batch_size": "False", + "preprocessing_num_workers": "None", + "peft_type": "none", }, metric_definitions=[ {"Name": "huggingface-text2text:eval-loss", "Regex": "'eval_loss': ([0-9\\.]+)"} ], role=execution_role, - encrypt_inter_container_traffic=False, + encrypt_inter_container_traffic=True, sagemaker_session=sagemaker_session, - enable_network_isolation=False, + enable_network_isolation=True, tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "js-trainable-model-prepacked"}, - {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.2.0"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "2.2.3"}, ], + volume_size=512, + max_run=360000, + disable_output_compression=True, ) channels = { @@ -264,37 +317,51 @@ def test_prepacked( estimator.fit(channels) - mock_estimator_fit.assert_called_once_with(inputs=channels, wait=True) + mock_estimator_fit.assert_called_once_with( + inputs=channels, wait=True, job_name="hf-text2text-flan-t5-base-8675309" + ) estimator.deploy() mock_estimator_deploy.assert_called_once_with( - instance_type="ml.g5.xlarge", + instance_type="ml.g5.2xlarge", initial_instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:" - "1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", env={ "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", "MODEL_CACHE_ROOT": "/opt/ml/model", "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SM_NUM_GPUS": "1", + "MAX_INPUT_LENGTH": "1024", + "MAX_TOTAL_TOKENS": "2048", "SAGEMAKER_MODEL_SERVER_WORKERS": "1", }, predictor_cls=Predictor, + endpoint_name="hf-text2text-flan-t5-base-8675309", role=execution_role, wait=True, use_compiled_model=False, - enable_network_isolation=False, + 
enable_network_isolation=True, tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "js-trainable-model-prepacked"}, - {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.2.0"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "2.2.3"}, ], + model_data_download_timeout=1200, + container_startup_health_check_timeout=1200, + model_name="hf-text2text-flan-t5-base-8675309", ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -325,34 +392,53 @@ def test_gated_model_s3_uri( mock_session_estimator.return_value = sagemaker_session mock_session_model.return_value = sagemaker_session - with pytest.raises(ValueError) as e: - JumpStartEstimator( - model_id=model_id, - environment={ - "accept_eula": "false", - "what am i": "doing", - "SageMakerGatedModelS3Uri": "none of your business", - }, - ) - assert str(e.value) == ( - "Need to define ‘accept_eula'='true' within Environment. " - "Model 'meta-textgeneration-llama-2-7b-f' requires accepting end-user " - "license agreement (EULA). See " - "https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/fmhMetadata/eula/llamaEula.txt" - " for terms of use." - ) - mock_estimator_init.reset_mock() estimator = JumpStartEstimator(model_id=model_id, environment={"accept_eula": "true"}) mock_estimator_init.assert_called_once_with( - instance_type="ml.p3.2xlarge", + instance_type="ml.g5.12xlarge", instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-" + "pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" - "meta/transfer_learning/textgeneration/v1.0.0/sourcedir.tar.gz", + "meta/transfer_learning/textgeneration/v1.0.6/sourcedir.tar.gz", entry_point="transfer_learning.py", + hyperparameters={ + "int8_quantization": "False", + "enable_fsdp": "True", + "epoch": "1", + "learning_rate": "0.0001", + "lora_r": "8", + "lora_alpha": "32", + "lora_dropout": "0.05", + "instruction_tuned": "False", + "chat_dataset": "True", + "add_input_output_demarcation_key": "True", + "per_device_train_batch_size": "1", + "per_device_eval_batch_size": "1", + "max_train_samples": "-1", + "max_val_samples": "-1", + "seed": "10", + "max_input_length": "-1", + "validation_split_ratio": "0.2", + "train_data_split_seed": "0", + "preprocessing_num_workers": "None", + }, + metric_definitions=[ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], role=execution_role, sagemaker_session=sagemaker_session, max_run=360000, @@ -360,14 +446,15 @@ def test_gated_model_s3_uri( 
encrypt_inter_container_traffic=True, environment={ "accept_eula": "true", - "SageMakerGatedModelS3Uri": "s3://jumpstart-cache-alpha-us-west-2/dummy.tar.gz", + "SageMakerGatedModelS3Uri": "s3://sagemaker-repository-pdx/" + "model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", }, tags=[ { "Key": "sagemaker-sdk:jumpstart-model-id", "Value": "js-gated-artifact-trainable-model", }, - {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.0"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.4"}, ], ) @@ -389,7 +476,7 @@ def test_gated_model_s3_uri( initial_instance_count=1, predictor_cls=Predictor, endpoint_name="meta-textgeneration-llama-2-7b-f-8675309", - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118", wait=True, model_data_download_timeout=3600, container_startup_health_check_timeout=3600, @@ -402,7 +489,152 @@ def test_gated_model_s3_uri( "Key": "sagemaker-sdk:jumpstart-model-id", "Value": "js-gated-artifact-trainable-model", }, - {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.0"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.4"}, + ], + ) + + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_gated_model_s3_uri_with_eula_in_fit( + self, + mock_estimator_deploy: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session_estimator: mock.Mock, + mock_session_model: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_timestamp: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + + mock_timestamp.return_value = "8675309" + + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + + model_id, _ = "js-gated-artifact-trainable-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session_estimator.return_value = sagemaker_session + mock_session_model.return_value = sagemaker_session + + mock_estimator_init.reset_mock() + + estimator = JumpStartEstimator(model_id=model_id) + + mock_estimator_init.assert_called_once_with( + instance_type="ml.g5.12xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-" + "pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "meta/transfer_learning/textgeneration/v1.0.6/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={ + "int8_quantization": "False", + "enable_fsdp": "True", + "epoch": "1", + "learning_rate": "0.0001", + "lora_r": "8", + 
"lora_alpha": "32", + "lora_dropout": "0.05", + "instruction_tuned": "False", + "chat_dataset": "True", + "add_input_output_demarcation_key": "True", + "per_device_train_batch_size": "1", + "per_device_eval_batch_size": "1", + "max_train_samples": "-1", + "max_val_samples": "-1", + "seed": "10", + "max_input_length": "-1", + "validation_split_ratio": "0.2", + "train_data_split_seed": "0", + "preprocessing_num_workers": "None", + }, + metric_definitions=[ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + role=execution_role, + sagemaker_session=sagemaker_session, + max_run=360000, + enable_network_isolation=True, + encrypt_inter_container_traffic=True, + environment={ + "SageMakerGatedModelS3Uri": "s3://sagemaker-repository-pdx/" + "model-data-model-package_llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302", + }, + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.4"}, + ], + ) + + channels = { + "training": f"s3://{get_jumpstart_content_bucket(region)}/" + f"some-training-dataset-doesn't-matter", + } + + estimator.fit(channels, accept_eula=True) + + mock_estimator_fit.assert_called_once_with( + inputs=channels, + wait=True, + job_name="meta-textgeneration-llama-2-7b-f-8675309", + ) + + assert hasattr(estimator, "model_access_config") + assert hasattr(estimator, "hub_access_config") + + assert estimator.model_access_config == {"AcceptEula": True} + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + predictor_cls=Predictor, + endpoint_name="meta-textgeneration-llama-2-7b-f-8675309", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118", + wait=True, + model_data_download_timeout=3600, + container_startup_health_check_timeout=3600, + role=execution_role, + enable_network_isolation=True, + model_name="meta-textgeneration-llama-2-7b-f-8675309", + use_compiled_model=False, + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.0.4"}, ], ) @@ -411,8 +643,12 @@ def test_gated_model_s3_uri( ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -559,8 +795,12 @@ def test_gated_model_non_model_package_s3_uri( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - 
@mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -597,7 +837,7 @@ def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( assert ( str(e.value) == "Model package artifact s3 uri for 'js-gated-artifact-trainable-model' " "not supported in eu-north-1. Please try one of the following regions: " - "us-west-2, us-east-1, eu-west-1, ap-southeast-1." + "us-west-2, us-east-2, us-east-1, eu-west-1, ap-southeast-1, ap-southeast-2." ) @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @@ -680,7 +920,6 @@ def test_estimator_use_kwargs(self): "input_mode": "File", "output_path": "Optional[Union[str, PipelineVariable]] = None", "output_kms_key": "Optional[Union[str, PipelineVariable]] = None", - "base_job_name": "Optional[str] = None", "sagemaker_session": DEFAULT_JUMPSTART_SAGEMAKER_SESSION, "hyperparameters": {"hyp1": "val1"}, "tags": [], @@ -751,8 +990,12 @@ def test_estimator_use_kwargs(self): @mock.patch("sagemaker.jumpstart.factory.estimator.environment_variables.retrieve_default") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -971,7 +1214,7 @@ def test_jumpstart_estimator_tags( js_tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "js-trainable-model-prepacked"}, - {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "1.2.0"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "2.2.3"}, ] self.assertEqual( @@ -1009,12 +1252,16 @@ def test_jumpstart_estimator_attach_eula_model( additional_kwargs={ "model_id": "gemma-model", "model_version": "*", + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, "environment": {"accept_eula": "true"}, + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, }, ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1022,15 +1269,17 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_id_version_from_training_job: mock.Mock, + get_model_info_from_training_job: 
mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_id_version_from_training_job.return_value = ( + get_model_info_from_training_job.return_value = ( "js-trainable-model-prepacked", "1.0.0", + None, + None, ) mock_get_model_specs.side_effect = get_special_model_spec @@ -1041,7 +1290,7 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1053,11 +1302,13 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case( additional_kwargs={ "model_id": "js-trainable-model-prepacked", "model_version": "1.0.0", + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, }, ) @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") - @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1065,13 +1316,13 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, - get_model_id_version_from_training_job: mock.Mock, + get_model_info_from_training_job: mock.Mock, mock_attach: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - get_model_id_version_from_training_job.side_effect = ValueError() + get_model_info_from_training_job.side_effect = ValueError() mock_get_model_specs.side_effect = get_special_model_spec @@ -1082,7 +1333,7 @@ def test_jumpstart_estimator_attach_no_model_id_sad_case( training_job_name="some-training-job-name", sagemaker_session=mock_session ) - get_model_id_version_from_training_job.assert_called_once_with( + get_model_info_from_training_job.assert_called_once_with( training_job_name="some-training-job-name", sagemaker_session=mock_session, ) @@ -1095,7 +1346,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): and reach out to JumpStart team.""" init_args_to_skip: Set[str] = set(["kwargs"]) - fit_args_to_skip: Set[str] = set() + fit_args_to_skip: Set[str] = set(["accept_eula"]) deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Estimator.__init__ @@ -1109,6 +1360,8 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", + "config_name", + "hub_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1118,8 +1371,8 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): js_class_fit = JumpStartEstimator.fit js_class_fit_args = set(signature(js_class_fit).parameters.keys()) - assert js_class_fit_args - parent_class_fit_args == set() - assert parent_class_fit_args - js_class_fit_args == fit_args_to_skip + assert js_class_fit_args - parent_class_fit_args == fit_args_to_skip + assert parent_class_fit_args - js_class_fit_args == set() model_class_init = Model.__init__ model_class_init_args = set(signature(model_class_init).parameters.keys()) @@ -1130,11 
+1383,15 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): js_class_deploy = JumpStartEstimator.deploy js_class_deploy_args = set(signature(js_class_deploy).parameters.keys()) - assert js_class_deploy_args - parent_class_deploy_args == model_class_init_args - { + assert js_class_deploy_args - parent_class_deploy_args - { + "inference_config_name" + } == model_class_init_args - { "model_data", + "additional_model_data_sources", "self", "name", "resources", + "model_reference_arn", } assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip @@ -1156,8 +1413,12 @@ def test_validate_model_id_and_get_type( @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -1209,14 +1470,20 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=estimator.sagemaker_session, + config_name=None, + hub_arn=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -1265,8 +1532,12 @@ def test_no_predictor_yes_async_inference_config( @mock.patch("sagemaker.jumpstart.estimator.get_default_predictor") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -1317,8 +1588,12 @@ def test_yes_predictor_returns_unmodified_predictor( @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - 
@mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -1366,13 +1641,20 @@ def test_incremental_training_with_unsupported_model_logs_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + config_name=None, + hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -1417,12 +1699,19 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + config_name=None, + hub_arn=None, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -1480,10 +1769,17 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta estimator.deploy(image_uri="blah") assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge" + estimator.deploy(image_uri="blah", instance_type="ml.quantum.large") + assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.quantum.large" + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") 
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -1560,10 +1856,11 @@ def test_training_passes_role_to_deploy( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @mock.patch( - "sagemaker.jumpstart.factory.model.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", sagemaker_session + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix", + lambda *largs, **kwargs: sagemaker_session, ) @mock.patch( - "sagemaker.jumpstart.factory.estimator.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix", sagemaker_session, ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1641,7 +1938,9 @@ def test_training_passes_session_to_deploy( @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.JumpStartModelsAccessor.reset_cache") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1681,6 +1980,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -1688,6 +1988,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), ] ) @@ -1709,6 +2010,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -1716,12 +2018,15 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), ] ) @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1742,59 +2047,68 @@ def test_model_artifact_variant_estimator( mock_session.return_value = sagemaker_session # this instance type has a special model artifact - JumpStartEstimator(model_id=model_id, instance_type="ml.p2.xlarge") + JumpStartEstimator(model_id=model_id, instance_type="ml.m5.xlarge") mock_estimator_init.assert_called_once_with( - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.13.1" - "-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - model_uri="s3://jumpstart-cache-prod-us-west-2/hello-mars-1", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-cpu-py38", + 
model_uri="s3://jumpstart-cache-prod-us-west-2/hello-world-1", source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" - "pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "pytorch/transfer_learning/ic/prepack/v1.1.0/sourcedir.tar.gz", entry_point="transfer_learning.py", - hyperparameters={"epochs": "3", "adam-learning-rate": "0.05", "batch-size": "4"}, + hyperparameters={ + "train_only_top_layer": "True", + "epochs": "5", + "learning_rate": "0.001", + "batch_size": "4", + "reinitialize_top_layer": "Auto", + }, metric_definitions=[ - {"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"} + {"Name": "pytorch-ic:val-accuracy", "Regex": "val_accuracy: ([0-9\\.]+)"} ], role=execution_role, sagemaker_session=sagemaker_session, - enable_network_isolation=False, + enable_network_isolation=True, encrypt_inter_container_traffic=True, - volume_size=456, + max_run=360000, tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "model-artifact-variant-model"}, - {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, ], ) mock_estimator_init.reset_mock() - JumpStartEstimator(model_id=model_id, instance_type="ml.p99.xlarge") + JumpStartEstimator(model_id=model_id, instance_type="ml.p3.2xlarge") mock_estimator_init.assert_called_once_with( - instance_type="ml.p99.xlarge", + instance_type="ml.p3.2xlarge", instance_count=1, - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", - model_uri="s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", - source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" - "transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38", + model_uri="s3://jumpstart-cache-prod-us-west-2/pytorch-training/" + "v2.0.0/train-pytorch-ic-mobilenet-v2.tar.gz", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/transfer_learning/ic/prepack/v1.1.0/sourcedir.tar.gz", entry_point="transfer_learning.py", - hyperparameters={"epochs": "3", "adam-learning-rate": "0.05", "batch-size": "4"}, + hyperparameters={ + "train_only_top_layer": "True", + "epochs": "5", + "learning_rate": "0.001", + "batch_size": "4", + "reinitialize_top_layer": "Auto", + }, metric_definitions=[ - {"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"} + {"Name": "pytorch-ic:val-accuracy", "Regex": "val_accuracy: ([0-9\\.]+)"} ], role=execution_role, sagemaker_session=sagemaker_session, - enable_network_isolation=False, + enable_network_isolation=True, encrypt_inter_container_traffic=True, - volume_size=456, + max_run=360000, tags=[ - { - "Key": "sagemaker-sdk:jumpstart-model-id", - "Value": "model-artifact-variant-model", - }, - {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_ID, "Value": "model-artifact-variant-model"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, ], ) @@ -1848,6 +2162,321 @@ def test_jumpstart_estimator_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + 
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_initialization_with_config_name( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + + mock_sagemaker_timestamp.return_value = "8675309" + + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator( + model_id=model_id, + config_name="gpu-training", + ) + + mock_estimator_init.assert_called_with( + instance_type="ml.p2.xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-training:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", + model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/" + "meta-textgeneration-llama-2-7b/gpu-training/model/", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/transfer_learning/ic/prepack/v1.1.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={ + "train_only_top_layer": "True", + "epochs": "5", + "learning_rate": "0.001", + "batch_size": "4", + "reinitialize_top_layer": "Auto", + }, + metric_definitions=[ + {"Name": "pytorch-ic:val-accuracy", "Regex": "val_accuracy: ([0-9\\.]+)"} + ], + role="fake role! 
do not use!", + max_run=360000, + sagemaker_session=estimator.sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training"}, + ], + enable_network_isolation=True, + encrypt_inter_container_traffic=True, + ) + + estimator.fit() + + mock_estimator_fit.assert_called_once_with(wait=True, job_name="pt-ic-mobilenet-v2-8675309") + + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_set_config_name( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + + mock_sagemaker_timestamp.return_value = "8675309" + + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + estimator.set_training_config(config_name="gpu-training-budget") + + mock_estimator_init.assert_called_with( + instance_type="ml.p2.xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training:1.13.1-py310-sdk2.14.1-ubuntu20.04", + model_uri="s3://jumpstart-cache-prod-us-west-2/artifacts/meta-textgeneration-llama-2-7b/" + "gpu-training-budget/model/", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/transfer_learning/ic/prepack/v1.1.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={ + "train_only_top_layer": "True", + "epochs": "5", + "learning_rate": "0.001", + "batch_size": "4", + "reinitialize_top_layer": "Auto", + }, + metric_definitions=[ + {"Name": "pytorch-ic:val-accuracy", "Regex": "val_accuracy: ([0-9\\.]+)"} + ], + role="fake role! 
do not use!", + max_run=360000, + sagemaker_session=estimator.sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "gpu-training-budget"}, + ], + enable_network_isolation=True, + encrypt_inter_container_traffic=True, + ) + + estimator.fit() + + mock_estimator_fit.assert_called_once_with(wait=True, job_name="pt-ic-mobilenet-v2-8675309") + + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_default_inference_config( + self, + mock_estimator_fit: mock.Mock, + mock_estimator_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + mock_sagemaker_timestamp.return_value = "8675309" + mock_estimator_deploy.return_value = default_predictor + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + assert estimator.config_name == "gpu-training" + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.m5.large", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + env={ + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + predictor_cls=Predictor, + wait=True, + role="fake role! 
do not use!", + use_compiled_model=False, + enable_network_isolation=True, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference"}, + ], + model_name="pt-ic-mobilenet-v2-8675309", + endpoint_name="pt-ic-mobilenet-v2-8675309", + ) + + @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") + @mock.patch("sagemaker.jumpstart.estimator.get_model_info_from_training_job") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_incremental_training_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_model_info_from_training_job: mock.Mock, + mock_attach: mock.Mock, + ): + mock_get_model_info_from_training_job.return_value = ( + "pytorch-ic-mobilenet-v2", + "1.0.0", + None, + "gpu-training-budget", + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training") + + assert estimator.config_name == "gpu-training" + + JumpStartEstimator.attach( + training_job_name="some-training-job-name", sagemaker_session=mock_session + ) + + mock_attach.assert_called_once_with( + training_job_name="some-training-job-name", + sagemaker_session=mock_session, + model_channel_name="model", + additional_kwargs={ + "model_id": "pytorch-ic-mobilenet-v2", + "model_version": "1.0.0", + "tolerate_vulnerable_model": True, + "tolerate_deprecated_model": True, + "config_name": "gpu-training-budget", + }, + ) + + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_deploy_with_config( + self, + mock_estimator_init: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + + mock_sagemaker_timestamp.return_value = "8675309" + mock_estimator_deploy.return_value = default_predictor + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, 
model_type) + ) + mock_estimator_fit.return_value = default_predictor + + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_session.return_value = sagemaker_session + + estimator = JumpStartEstimator(model_id=model_id, config_name="gpu-training-budget") + + assert estimator.config_name == "gpu-training-budget" + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.m5.large", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + env={ + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + predictor_cls=Predictor, + wait=True, + role="fake role! do not use!", + use_compiled_model=False, + enable_network_isolation=True, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference-budget"}, + ], + model_name="pt-ic-mobilenet-v2-8675309", + endpoint_name="pt-ic-mobilenet-v2-8675309", + ) + def test_jumpstart_estimator_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py index 073921d5ba..39eca166ee 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py @@ -123,16 +123,16 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config( mock_retrieve_model_init_kwargs.return_value = {} - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), config_role) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), config_role) assert "enable_network_isolation" not in mock_estimator_init.call_args[1] assert "encrypt_inter_container_traffic" not in mock_estimator_init.call_args[1] estimator.deploy() - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), config_inference_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), config_inference_role) assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] @@ -181,13 +181,13 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( model_id=model_id, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), config_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), config_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), config_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), config_intercontainer_encryption, ) @@ -200,11 +200,11 @@ def test_without_arg_overwrites_with_kwarg_collisions_with_config( estimator.deploy() - self.assertEquals(mock_get_sagemaker_config_value.call_count, 6) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 6) - 
self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), config_inference_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), config_inference_role) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), config_inference_enable_network_isolation, ) @@ -257,13 +257,13 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( encrypt_inter_container_traffic=override_encrypt_inter_container_traffic, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), override_encrypt_inter_container_traffic, ) @@ -280,13 +280,13 @@ def test_with_arg_overwrites_with_kwarg_collisions_with_config( enable_network_isolation=override_inference_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("role"), mock_inference_override_role ) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), override_inference_enable_network_isolation, ) @@ -336,13 +336,13 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( encrypt_inter_container_traffic=override_encrypt_inter_container_traffic, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), override_encrypt_inter_container_traffic, ) @@ -355,13 +355,13 @@ def test_with_arg_overwrites_without_kwarg_collisions_with_config( enable_network_isolation=override_inference_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("role"), mock_inference_override_role ) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), override_inference_enable_network_isolation, ) @@ -412,8 +412,8 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( model_id=model_id, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), execution_role) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_estimator_init.call_args[1] assert "encrypt_inter_container_traffic" not in mock_estimator_init.call_args[1] 
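A side note on the mechanical assertEquals → assertEqual renames running through this file: assertEquals is one of the deprecated unittest aliases (deprecated since Python 3.2) that were removed outright in Python 3.12, where calling the old spelling raises AttributeError. A minimal, standalone illustration, not taken from the SDK test suite:

import unittest


class AliasExample(unittest.TestCase):
    def test_sum(self):
        # assertEqual is the supported spelling; the assertEquals alias was
        # deprecated in Python 3.2 and removed in Python 3.12.
        self.assertEqual(2 + 2, 4)


if __name__ == "__main__":
    unittest.main()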
@@ -421,9 +421,9 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config( mock_retrieve_model_init_kwargs.return_value = {} - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), execution_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), execution_role) assert "enable_network_isolation" not in mock_estimator_deploy.call_args[1] @@ -475,13 +475,13 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( model_id=model_id, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), execution_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), execution_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), metadata_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), metadata_intercontainer_encryption, ) @@ -492,11 +492,11 @@ def test_without_arg_overwrites_with_kwarg_collisions_without_config( estimator.deploy() - self.assertEquals(mock_get_sagemaker_config_value.call_count, 6) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 6) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), execution_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), execution_role) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), metadata_inference_enable_network_isolation, ) @@ -548,13 +548,13 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( encrypt_inter_container_traffic=override_encrypt_inter_container_traffic, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), override_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), override_encrypt_inter_container_traffic, ) @@ -568,11 +568,11 @@ def test_with_arg_overwrites_with_kwarg_collisions_without_config( enable_network_isolation=override_inference_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), override_inference_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), override_inference_role) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), override_inference_enable_network_isolation, ) @@ -618,13 +618,13 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( enable_network_isolation=override_enable_network_isolation, encrypt_inter_container_traffic=override_encrypt_inter_container_traffic, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 1) - self.assertEquals(mock_estimator_init.call_args[1].get("role"), 
override_role) - self.assertEquals( + self.assertEqual(mock_get_sagemaker_config_value.call_count, 1) + self.assertEqual(mock_estimator_init.call_args[1].get("role"), override_role) + self.assertEqual( mock_estimator_init.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) - self.assertEquals( + self.assertEqual( mock_estimator_init.call_args[1].get("encrypt_inter_container_traffic"), override_encrypt_inter_container_traffic, ) @@ -634,11 +634,11 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config( enable_network_isolation=override_enable_network_isolation, ) - self.assertEquals(mock_get_sagemaker_config_value.call_count, 3) + self.assertEqual(mock_get_sagemaker_config_value.call_count, 3) - self.assertEquals(mock_estimator_deploy.call_args[1].get("role"), override_inference_role) + self.assertEqual(mock_estimator_deploy.call_args[1].get("role"), override_inference_role) - self.assertEquals( + self.assertEqual( mock_estimator_deploy.call_args[1].get("enable_network_isolation"), override_enable_network_isolation, ) diff --git a/tests/unit/sagemaker/jumpstart/factory/__init__.py b/tests/unit/sagemaker/jumpstart/factory/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/factory/test_estimator.py b/tests/unit/sagemaker/jumpstart/factory/test_estimator.py new file mode 100644 index 0000000000..fd59961f09 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/factory/test_estimator.py @@ -0,0 +1,162 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
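The new tests/unit/sagemaker/jumpstart/factory/test_estimator.py below pins down the model-URI resolution in _add_model_uri_to_kwargs. As a hedged paraphrase of the precedence those tests assert — the function name, parameters, and warn stand-in here are illustrative, not the factory's actual signature:

def resolve_training_model_uri(
    user_uri,                # model_uri passed by the caller, may be None
    default_uri,             # default artifact from model_uris.retrieve
    is_private_hub,          # hub_arn names a hub other than JUMPSTART_MODEL_HUB_NAME
    supports_training_uri,   # result of _model_supports_training_model_uri
    supports_incremental,    # result of _model_supports_incremental_training
    warn=print,              # stand-in for JUMPSTART_LOGGER.warning
):
    """Sketch of which training model_uri ends up in the estimator kwargs."""
    # Private-hub models skip the capability check and always resolve a URI.
    if not (is_private_hub or supports_training_uri):
        return None  # model_uri stays unset
    if user_uri is None:
        return default_uri  # fall back to the JumpStart default artifact
    if not supports_incremental:
        # A custom URI still wins, but the factory warns that the model
        # does not support incremental training.
        warn("model does not support incremental training")
    return user_uri

Under this reading, the warning path is the only behavioral difference between the two custom-URI tests.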
+from __future__ import absolute_import +import pytest +from unittest.mock import patch +from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME +from sagemaker.jumpstart.factory.estimator import ( + _add_model_uri_to_kwargs, + get_model_info_default_kwargs, +) +from sagemaker.jumpstart.types import JumpStartEstimatorInitKwargs +from sagemaker.jumpstart.enums import JumpStartScriptScope + + +class TestAddModelUriToKwargs: + @pytest.fixture + def mock_kwargs(self): + return JumpStartEstimatorInitKwargs( + model_id="test-model", + model_version="1.0.0", + instance_type="ml.m5.large", + model_uri=None, + ) + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=True, + ) + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") + def test_add_model_uri_to_kwargs_default_uri( + self, mock_retrieve, mock_supports_training, mock_kwargs + ): + """Test adding default model URI when none is provided.""" + default_uri = "s3://jumpstart-models/training/test-model/1.0.0" + mock_retrieve.return_value = default_uri + + result = _add_model_uri_to_kwargs(mock_kwargs) + + mock_supports_training.assert_called_once() + mock_retrieve.assert_called_once_with( + model_scope=JumpStartScriptScope.TRAINING, + instance_type=mock_kwargs.instance_type, + **get_model_info_default_kwargs(mock_kwargs), + ) + assert result.model_uri == default_uri + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=True, + ) + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_incremental_training", + return_value=True, + ) + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") + def test_add_model_uri_to_kwargs_custom_uri_with_incremental( + self, mock_retrieve, mock_supports_incremental, mock_supports_training, mock_kwargs + ): + """Test using custom model URI with incremental training support.""" + default_uri = "s3://jumpstart-models/training/test-model/1.0.0" + custom_uri = "s3://custom-bucket/my-model" + mock_retrieve.return_value = default_uri + mock_kwargs.model_uri = custom_uri + + result = _add_model_uri_to_kwargs(mock_kwargs) + + mock_supports_training.assert_called_once() + mock_supports_incremental.assert_called_once() + assert result.model_uri == custom_uri + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=True, + ) + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_incremental_training", + return_value=False, + ) + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") + @patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") + def test_add_model_uri_to_kwargs_custom_uri_without_incremental( + self, + mock_warning, + mock_retrieve, + mock_supports_incremental, + mock_supports_training, + mock_kwargs, + ): + """Test using custom model URI without incremental training support logs warning.""" + default_uri = "s3://jumpstart-models/training/test-model/1.0.0" + custom_uri = "s3://custom-bucket/my-model" + mock_retrieve.return_value = default_uri + mock_kwargs.model_uri = custom_uri + + result = _add_model_uri_to_kwargs(mock_kwargs) + + mock_supports_training.assert_called_once() + mock_supports_incremental.assert_called_once() + mock_warning.assert_called_once() + assert "does not support incremental training" in mock_warning.call_args[0][0] + assert result.model_uri == custom_uri + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + 
return_value=False, + ) + def test_add_model_uri_to_kwargs_no_training_support(self, mock_supports_training, mock_kwargs): + """Test when model doesn't support training model URI.""" + result = _add_model_uri_to_kwargs(mock_kwargs) + + mock_supports_training.assert_called_once() + assert result.model_uri is None + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=False, + ) + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") + def test_add_model_uri_to_kwargs_private_hub( + self, mock_retrieve, mock_supports_training, mock_kwargs + ): + """Test when model is from a private hub.""" + default_uri = "s3://jumpstart-models/training/test-model/1.0.0" + mock_retrieve.return_value = default_uri + mock_kwargs.hub_arn = "arn:aws:sagemaker:us-west-2:123456789012:hub/private-hub" + + result = _add_model_uri_to_kwargs(mock_kwargs) + + # Should not check if model supports training model URI for private hub + mock_supports_training.assert_not_called() + mock_retrieve.assert_called_once() + assert result.model_uri == default_uri + + @patch( + "sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri", + return_value=False, + ) + @patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve") + def test_add_model_uri_to_kwargs_public_hub( + self, mock_retrieve, mock_supports_training, mock_kwargs + ): + """Test when model is from the public hub.""" + mock_kwargs.hub_arn = ( + f"arn:aws:sagemaker:us-west-2:123456789012:hub/{JUMPSTART_MODEL_HUB_NAME}" + ) + + result = _add_model_uri_to_kwargs(mock_kwargs) + + # Should check if model supports training model URI for public hub + mock_supports_training.assert_called_once() + mock_retrieve.assert_not_called() + assert result.model_uri is None diff --git a/tests/unit/sagemaker/jumpstart/hub/__init__.py b/tests/unit/sagemaker/jumpstart/hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py new file mode 100644 index 0000000000..29efb6b31f --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -0,0 +1,255 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +from datetime import datetime +from unittest.mock import patch, MagicMock +import pytest +from mock import Mock +from sagemaker.jumpstart.hub.hub import Hub + + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + +MODULE_PATH = "sagemaker.jumpstart.hub.hub.Hub" + +FAKE_TIME = datetime(1997, 8, 14, 00, 00, 00) + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session") + sagemaker_session_mock = Mock( + name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION + ) + sagemaker_session_mock._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + sagemaker_session_mock.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": "s3://mock-bucket-123"} + } + sagemaker_session_mock.account_id.return_value = ACCOUNT_ID + return sagemaker_session_mock + + +@pytest.fixture +def mock_instance(sagemaker_session): + mock_instance = MagicMock() + mock_instance.hub_name = "test-hub" + mock_instance._sagemaker_session = sagemaker_session + return mock_instance + + +def test_instantiates(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + assert hub.hub_name == HUB_NAME + assert hub.region == "us-east-1" + assert hub._sagemaker_session == sagemaker_session + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +def test_create_with_no_bucket_name( + sagemaker_session, + hub_name, + hub_description, + hub_display_name, + hub_search_keywords, + tags, +): + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", "mock-bucket-123", None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + "mock-bucket-123", + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +@patch("sagemaker.jumpstart.hub.hub.datetime") +def test_create_with_bucket_name( + mock_datetime, + sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + mock_datetime.now.return_value = FAKE_TIME + + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name) + request = { + "hub_name": 
hub_name, + "hub_description": hub_description, + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "s3_storage_config": { + "S3OutputPath": f"s3://mock-bucket-123/{hub_name}-{FAKE_TIME.timestamp()}" + }, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") +def test_describe_model_success(mock_describe_hub_content_response, sagemaker_session): + mock_describe_hub_content_response.return_value = Mock() + mock_list_hub_content_versions = sagemaker_session.list_hub_content_versions + mock_list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0"}, + {"HubContentVersion": "2.0"}, + {"HubContentVersion": "3.0"}, + ] + } + + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + + with patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") as mock_get_hub_model_version: + mock_get_hub_model_version.return_value = "3.0" + + hub.describe_model("test-model") + + mock_list_hub_content_versions.assert_called_with( + hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="ModelReference" + ) + sagemaker_session.describe_hub_content.assert_called_with( + hub_name=HUB_NAME, + hub_content_name="test-model", + hub_content_version="3.0", + hub_content_type="ModelReference", + ) + + +@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") +def test_describe_model_one_thrown_error(mock_describe_hub_content_response, sagemaker_session): + mock_describe_hub_content_response.return_value = Mock() + mock_list_hub_content_versions = sagemaker_session.list_hub_content_versions + mock_list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0"}, + {"HubContentVersion": "2.0"}, + {"HubContentVersion": "3.0"}, + ] + } + mock_describe_hub_content = sagemaker_session.describe_hub_content + mock_describe_hub_content.side_effect = [ + Exception("Some exception"), + {"HubContentName": "test-model", "HubContentVersion": "3.0"}, + ] + + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + + with patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") as mock_get_hub_model_version: + mock_get_hub_model_version.return_value = "3.0" + + hub.describe_model("test-model") + + assert mock_describe_hub_content.call_count == 2 + mock_describe_hub_content.assert_called_with( + hub_name=HUB_NAME, + hub_content_name="test-model", + hub_content_version="3.0", + hub_content_type="Model", + ) + + +def test_create_hub_content_reference(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + model_name = "mock-model-one-huggingface" + min_version = "1.1.1" + public_model_arn = ( + f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/{model_name}" + ) + create_hub_content_reference = { + "HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{HUB_NAME}", + "HubContentReferenceArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/{HUB_NAME}/ModelRef/{model_name}", # noqa: E501 + } + sagemaker_session.create_hub_content_reference = Mock(return_value=create_hub_content_reference) + + request = { + "hub_name": HUB_NAME, + 
"source_hub_content_arn": public_model_arn, + "hub_content_name": model_name, + "min_version": min_version, + } + + response = hub.create_model_reference( + model_arn=public_model_arn, model_name=model_name, min_version=min_version + ) + sagemaker_session.create_hub_content_reference.assert_called_with(**request) + + assert response == { + "HubArn": "arn:aws:sagemaker:us-east-1:123456789123:hub/mock-hub-name", + "HubContentReferenceArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/mock-hub-name/ModelRef/mock-model-one-huggingface", # noqa: E501 + } + + +def test_delete_hub_content_reference(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + model_name = "mock-model-one-huggingface" + + hub.delete_model_reference(model_name) + sagemaker_session.delete_hub_content_reference.assert_called_with( + hub_name=HUB_NAME, + hub_content_type="ModelReference", + hub_content_name="mock-model-one-huggingface", + ) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py new file mode 100644 index 0000000000..ebd90d98d2 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py @@ -0,0 +1,1046 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pytest +import numpy as np +from sagemaker.jumpstart.types import ( + JumpStartConfigComponent, + JumpStartConfigRanking, + JumpStartHyperparameter, + JumpStartInstanceTypeVariants, + JumpStartEnvironmentVariable, + JumpStartMetadataConfig, + JumpStartMetadataConfigs, + JumpStartPredictorSpecs, + JumpStartSerializablePayload, +) +from sagemaker.jumpstart.hub.interfaces import HubModelDocument +from tests.unit.sagemaker.jumpstart.constants import ( + SPECIAL_MODEL_SPECS_DICT, + HUB_MODEL_DOCUMENT_DICTS, +) + +gemma_model_spec = SPECIAL_MODEL_SPECS_DICT["gemma-model-2b-v1_1_0"] + + +def test_hub_content_document_from_json_obj(): + region = "us-west-2" + json_obj = HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"] + gemma_model_document = HubModelDocument(json_obj=json_obj, region=region) + assert gemma_model_document.url == "https://huggingface.co/google/gemma-2b-it" + assert gemma_model_document.min_sdk_version == "2.189.0" + assert gemma_model_document.training_supported is True + assert gemma_model_document.incremental_training_supported is False + assert ( + gemma_model_document.hosting_ecr_uri + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:" + "2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + with pytest.raises(AttributeError) as excinfo: + gemma_model_document.hosting_ecr_specs + assert str(excinfo.value) == "'HubModelDocument' object has no attribute 'hosting_ecr_specs'" + assert gemma_model_document.hosting_artifact_s3_data_type == "S3Prefix" + assert gemma_model_document.hosting_artifact_compression_type == "None" + assert ( + gemma_model_document.hosting_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct" + "/artifacts/inference/v1.0.0/" + ) + assert ( + gemma_model_document.hosting_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/" + "llm/v1.0.1/sourcedir.tar.gz" + ) + assert gemma_model_document.inference_dependencies == [] + assert gemma_model_document.training_dependencies == [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ] + assert ( + gemma_model_document.hosting_prepacked_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/" + "artifacts/inference-prepack/v1.0.0/" + ) + assert gemma_model_document.hosting_prepacked_artifact_version == "1.0.0" + assert gemma_model_document.hosting_use_script_uri is False + assert ( + gemma_model_document.hosting_eula_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/fmhMetadata/terms/gemmaTerms.txt" + ) + assert ( + gemma_model_document.training_ecr_uri + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers" + "4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + with pytest.raises(AttributeError) as excinfo: + gemma_model_document.training_ecr_specs + assert str(excinfo.value) == "'HubModelDocument' object has no attribute 'training_ecr_specs'" + assert ( + 
gemma_model_document.training_prepacked_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/" + "llm/prepack/v1.1.1/sourcedir.tar.gz" + ) + assert gemma_model_document.training_prepacked_script_version == "1.1.1" + assert ( + gemma_model_document.training_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/" + "llm/v1.1.1/sourcedir.tar.gz" + ) + assert gemma_model_document.training_artifact_s3_data_type == "S3Prefix" + assert gemma_model_document.training_artifact_compression_type == "None" + assert ( + gemma_model_document.training_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct" + ".tar.gz" + ) + assert gemma_model_document.hyperparameters == [ + JumpStartHyperparameter( + { + "Name": "peft_type", + "Type": "text", + "Default": "lora", + "Options": ["lora", "None"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "chat_dataset", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "epoch", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lora_r", + "Type": "int", + "Default": 64, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "lora_alpha", "Type": "int", "Default": 16, "Min": 0, "Scope": "algorithm"}, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lora_dropout", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "bits", "Type": "int", "Default": 4, "Scope": "algorithm"}, is_hub_content=True + ), + JumpStartHyperparameter( + { + "Name": "double_quant", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "quant_Type", + "Type": "text", + "Default": "nf4", + "Options": ["fp4", "nf4"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "per_device_train_batch_size", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "per_device_eval_batch_size", + "Type": "int", + "Default": 2, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "warmup_ratio", + "Type": "float", + "Default": 0.1, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "train_from_scratch", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "fp16", + "Type": "text", + "Default": "True", + 
"Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "bf16", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "evaluation_strategy", + "Type": "text", + "Default": "steps", + "Options": ["steps", "epoch", "no"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "eval_steps", + "Type": "int", + "Default": 20, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "gradient_accumulation_steps", + "Type": "int", + "Default": 4, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_steps", + "Type": "int", + "Default": 8, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "weight_decay", + "Type": "float", + "Default": 0.2, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "load_best_model_at_end", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_train_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_val_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "seed", + "Type": "int", + "Default": 10, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_input_length", + "Type": "int", + "Default": 1024, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "validation_split_ratio", + "Type": "float", + "Default": 0.2, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "train_data_split_seed", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "preprocessing_num_workers", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "max_steps", "Type": "int", "Default": -1, "Scope": "algorithm"}, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "gradient_checkpointing", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "early_stopping_patience", + "Type": "int", + "Default": 3, + "Min": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "early_stopping_threshold", + "Type": "float", + "Default": 0.0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_beta1", + "Type": "float", + "Default": 0.9, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_beta2", + "Type": "float", + "Default": 0.999, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + 
is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_epsilon", + "Type": "float", + "Default": 1e-08, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_grad_norm", + "Type": "float", + "Default": 1.0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "label_smoothing_factor", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_first_step", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_nan_inf_filter", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_strategy", + "Type": "text", + "Default": "steps", + "Options": ["no", "epoch", "steps"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_steps", + "Type": "int", + "Default": 500, + "Min": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_total_limit", + "Type": "int", + "Default": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "dataloader_drop_last", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "dataloader_num_workers", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "eval_accumulation_steps", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "auto_find_batch_size", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lr_scheduler_type", + "Type": "text", + "Default": "constant_with_warmup", + "Options": ["constant_with_warmup", "linear"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "warmup_steps", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "deepspeed", + "Type": "text", + "Default": "False", + "Options": ["False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_submit_directory", + "Type": "text", + "Default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "Scope": "container", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_program", + "Type": "text", + "Default": "transfer_learning.py", + "Scope": "container", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_container_log_level", + "Type": "text", + "Default": "20", + "Scope": "container", + }, + is_hub_content=True, + ), + ] + assert gemma_model_document.inference_environment_variables == [ + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + 
JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_SUBMIT_DIRECTORY", + "Type": "text", + "Default": "/opt/ml/model/code", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "Type": "text", + "Default": "20", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "Type": "text", + "Default": "3600", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "ENDPOINT_SERVER_TIMEOUT", + "Type": "int", + "Default": 3600, + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MODEL_CACHE_ROOT", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_ENV", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "HF_MODEL_ID", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_INPUT_LENGTH", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_TOTAL_TOKENS", + "Type": "text", + "Default": "8192", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_BATCH_PREFILL_TOKENS", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SM_NUM_GPUS", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "Type": "int", + "Default": 1, + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + ] + assert gemma_model_document.training_metrics == [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss': ([0-9]+\\.[0-9]+)", + }, + ] + assert gemma_model_document.default_inference_instance_type == "ml.g5.xlarge" + assert gemma_model_document.supported_inference_instance_types == [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ] + assert gemma_model_document.default_training_instance_type == "ml.g5.2xlarge" + assert np.array_equal( + gemma_model_document.supported_training_instance_types, + [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + ) + assert gemma_model_document.sage_maker_sdk_predictor_specifications == JumpStartPredictorSpecs( + { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + 
"DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + is_hub_content=True, + ) + assert gemma_model_document.inference_volume_size == 512 + assert gemma_model_document.training_volume_size == 512 + assert gemma_model_document.inference_enable_network_isolation is True + assert gemma_model_document.training_enable_network_isolation is True + assert gemma_model_document.fine_tuning_supported is True + assert gemma_model_document.validation_supported is True + assert ( + gemma_model_document.default_training_dataset_uri + == "s3://jumpstart-cache-prod-us-west-2/training-datasets/oasst_top/train/" + ) + assert gemma_model_document.resource_name_base == "hf-llm-gemma-2b-instruct" + assert gemma_model_document.default_payloads == { + "HelloWorld": JumpStartSerializablePayload( + { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "user\nWrite a hello world program" + "\nmodel", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, + }, + }, + is_hub_content=True, + ), + "MachineLearningPoem": JumpStartSerializablePayload( + { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "Write me a poem about Machine Learning.", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, + }, + }, + is_hub_content=True, + ), + } + assert gemma_model_document.gated_bucket is True + assert gemma_model_document.hosting_resource_requirements == { + "MinMemoryMb": 8192, + "NumAccelerators": 1, + } + assert gemma_model_document.hosting_instance_type_variants == JumpStartInstanceTypeVariants( + { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch" + "-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + is_hub_content=True, + ) + assert gemma_model_document.training_instance_type_variants == JumpStartInstanceTypeVariants( + { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-" + "training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "Variants": { + "g4dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g4dn/v1.0.0/", # noqa: E501 + }, + }, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + 
"training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/g5/v1.0.0/", # noqa: E501 + }, + }, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p3dn/v1.0.0/", # noqa: E501 + }, + }, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "training_artifact_uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/p4d/v1.0.0/", # noqa: E501 + }, + }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + is_hub_content=True, + ) + assert gemma_model_document.contextual_help == { + "HubFormatTrainData": [ + "A train and an optional validation directories. Each directory contains a CSV/JSON/TXT. ", + "- For CSV/JSON files, the text data is used from the column called 'text' or the " + "first column if no column called 'text' is found", + "- The number of files under train and validation (if provided) should equal to one," + " respectively.", + " [Learn how to setup an AWS S3 bucket.]" + "(https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html)", + ], + "HubDefaultTrainData": [ + "Dataset: [SEC](https://www.sec.gov/edgar/searchedgar/companysearch)", + "SEC filing contains regulatory documents that companies and issuers of securities must " + "submit to the Securities and Exchange Commission (SEC) on a regular basis.", + "License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode)", + ], + } + assert gemma_model_document.model_data_download_timeout == 1200 + assert gemma_model_document.container_startup_health_check_timeout == 1200 + assert gemma_model_document.encrypt_inter_container_traffic is True + assert gemma_model_document.disable_output_compression is True + assert gemma_model_document.max_runtime_in_seconds == 360000 + assert gemma_model_document.dynamic_container_deployment_supported is True + assert gemma_model_document.training_model_package_artifact_uri is None + assert gemma_model_document.dependencies == [] + + inference_config_rankings = { + "overall": JumpStartConfigRanking( + {"Description": "default", "Rankings": ["variant1"]}, is_hub_content=True + ) + } + + inference_config_components = { + "variant1": JumpStartConfigComponent( + "variant1", + { + "HostingEcrUri": "123456789012.ecr.us-west-2.amazon.com/repository", + "HostingArtifactUri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/variant1/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-monarch-test-hub-bucket/monarch-curated-hub-1714579993.88695/curated_models/meta-textgeneration-llama-2-7b/4.0.0/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", # noqa: E501 + "InferenceDependencies": [], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + } + ], + "HostingAdditionalDataSources": { + "speculative_decoding": [ + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_1", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/1", + }, + }, + { + "ArtifactVersion": 1, + "ChannelName": "speculative_decoding_channel_2", 
+ "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/path/2", + }, + }, + ] + }, + }, + is_hub_content=True, + ) + } + + inference_configs_dict = { + "variant1": JumpStartMetadataConfig( + "variant1", + json_obj["InferenceConfigs"]["variant1"], + json_obj, + inference_config_components, + is_hub_content=True, + ) + } + + inference_configs = JumpStartMetadataConfigs(inference_configs_dict, inference_config_rankings) + + assert gemma_model_document.inference_config_rankings == inference_config_rankings + assert gemma_model_document.inference_config_components == inference_config_components + assert gemma_model_document.inference_configs == inference_configs diff --git a/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py b/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py new file mode 100644 index 0000000000..49d97d177d --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_marketplace_hub_content.py @@ -0,0 +1,132 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from unittest.mock import patch, MagicMock +from mock import Mock +from sagemaker.jumpstart.hub import utils as hub_utils +from sagemaker.jumpstart.enums import JumpStartModelType +from sagemaker.jumpstart.utils import _validate_hub_service_model_id_and_get_type + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + +MOCK_MODEL_ID = "test-model-id" + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session") + sagemaker_session_mock = Mock( + name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION + ) + sagemaker_session_mock._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + sagemaker_session_mock.account_id.return_value = ACCOUNT_ID + return sagemaker_session_mock + + +@pytest.mark.parametrize( + "input_version, expected_version, expected_exception, expected_message", + [ + ("1.0.0", "1.0.0", None, None), + ("*", "3.2.0", None, None), + (None, "3.2.0", None, None), + ("1.*", "1.1.0", None, None), + ("240612.4", "2.0.0", None, None), + ("3.0.0", "3.0.0", None, None), + ("4.0.0", "3.2.0", None, None), + ("5.0.0", None, KeyError, "Model version not available in the Hub"), + ("Blah", None, KeyError, "Bad semantic version"), + ], +) +def test_proprietary_model( + input_version, expected_version, expected_exception, expected_message, sagemaker_session +): + sagemaker_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0", "HubContentSearchKeywords": []}, + {"HubContentVersion": "1.1.0", "HubContentSearchKeywords": []}, + { + "HubContentVersion": "2.0.0", + "HubContentSearchKeywords": ["@marketplace-version:240612.4"], + }, + { + "HubContentVersion": "3.0.0", + "HubContentSearchKeywords": ["@marketplace-version:240612.5"], + }, + { + "HubContentVersion": "3.1.0", + 
"HubContentSearchKeywords": ["@marketplace-version:3.0.0"], + }, + { + "HubContentVersion": "3.2.0", + "HubContentSearchKeywords": ["@marketplace-version:4.0.0"], + }, + ] + } + + if expected_exception: + with pytest.raises(expected_exception, match=expected_message): + _test_proprietary_model(input_version, expected_version, sagemaker_session) + else: + _test_proprietary_model(input_version, expected_version, sagemaker_session) + + +def _test_proprietary_model(input_version, expected_version, sagemaker_session): + result = hub_utils.get_hub_model_version( + hub_model_name=MOCK_MODEL_ID, + hub_model_type="Model", + hub_name="blah", + sagemaker_session=sagemaker_session, + hub_model_version=input_version, + ) + + assert result == expected_version + + +@pytest.mark.parametrize( + "get_model_specs_attr, get_model_specs_response, expected, expected_exception, expected_message", + [ + (False, None, [], None, None), + (True, None, [], None, None), + (True, [], [], None, None), + (True, ["OPEN_WEIGHTS"], [JumpStartModelType.OPEN_WEIGHTS], None, None), + ( + True, + ["OPEN_WEIGHTS", "PROPRIETARY"], + [JumpStartModelType.OPEN_WEIGHTS, JumpStartModelType.PROPRIETARY], + None, + None, + ), + ], +) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_validate_hub_service_model_id_and_get_type( + mock_get_model_specs, + get_model_specs_attr, + get_model_specs_response, + expected, + expected_exception, + expected_message, +): + mock_object = MagicMock() + if get_model_specs_attr: + mock_object.model_types = get_model_specs_response + mock_get_model_specs.return_value = mock_object + + result = _validate_hub_service_model_id_and_get_type(model_id="blah", hub_arn="blah") + + assert result == expected diff --git a/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py new file mode 100644 index 0000000000..8a2da5b165 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_parser_utils.py @@ -0,0 +1,53 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import pytest +from sagemaker.jumpstart.hub.parser_utils import camel_to_snake +from sagemaker.jumpstart.hub.parsers import make_model_specs_from_describe_hub_content_response +from sagemaker.jumpstart.hub.interfaces import HubModelDocument +from tests.unit.sagemaker.jumpstart.constants import HUB_MODEL_DOCUMENT_DICTS +from unittest.mock import MagicMock +from sagemaker.jumpstart.types import HubContentType + + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + + +@pytest.mark.parametrize( + "input_string, expected", + [ + ("camelCase", "camel_case"), + ("PascalCase", "pascal_case"), + ("already_snake", "already_snake"), + ("", ""), + ("A", "a"), + ("PascalCase123", "pascal_case123"), + ("123StartWithNumber", "123_start_with_number"), + ], +) +def test_parse_camelCase(input_string, expected): + assert expected == camel_to_snake(input_string) + + +def test_make_model_specs_from_describe_hub_content_response(): + mock_describe_response = MagicMock() + region = "us-west-2" + mock_describe_response.hub_content_type = HubContentType.MODEL + mock_describe_response.get_hub_region.return_value = region + mock_describe_response.hub_content_version = "1.0.0" + json_obj = HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"] + mock_describe_response.hub_content_document = HubModelDocument(json_obj=json_obj, region=region) + + make_model_specs_from_describe_hub_content_response(mock_describe_response) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py new file mode 100644 index 0000000000..5745a7f79c --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -0,0 +1,295 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from unittest.mock import patch, Mock +from sagemaker.jumpstart.types import HubArnExtractedInfo +from sagemaker.jumpstart.constants import ( + JUMPSTART_DEFAULT_REGION_NAME, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) +from sagemaker.jumpstart.hub import parser_utils, utils + + +def test_get_info_from_hub_resource_arn(): + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Model/my-mock-model/1.0.2" + ) + assert utils.get_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Model", + hub_content_name="my-mock-model", + hub_content_version="1.0.2", + ) + + notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Notebook/my-mock-notebook/1.0.2" + assert utils.get_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Notebook", + hub_content_name="my-mock-notebook", + hub_content_version="1.0.2", + ) + + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" + assert utils.get_info_from_hub_resource_arn(hub_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + ) + + invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "nonsense-string" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + +def test_construct_hub_arn_from_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-west-2" + hub_name = "my-cool-hub" + + assert ( + utils.construct_hub_arn_from_name(hub_name=hub_name, session=mock_sagemaker_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-cool-hub" + ) + + assert ( + utils.construct_hub_arn_from_name( + hub_name=hub_name, region="us-east-1", session=mock_sagemaker_session + ) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-cool-hub" + ) + + +def test_construct_hub_arn_from_name_with_session_none(): + hub_name = "my-cool-hub" + account_id = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.account_id() + boto_region_name = DEFAULT_JUMPSTART_SAGEMAKER_SESSION.boto_region_name + + assert ( + utils.construct_hub_arn_from_name(hub_name=hub_name, session=None) + == f"arn:aws:sagemaker:{boto_region_name}:{account_id}:hub/{hub_name}" + ) + + +def test_construct_hub_model_arn_from_inputs(): + model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/1.0.2" + ) + + version = "*" + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/*" + ) + + +def test_construct_hub_model_reference_arn_from_inputs(): + model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + hub_content_arn_prefix = 
"arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub" + + assert ( + utils.construct_hub_model_reference_arn_from_inputs(hub_arn, model_name, version) + == hub_content_arn_prefix + "/ModelReference/pytorch-ic-imagenet-v2/1.0.2" + ) + + version = "*" + assert ( + utils.construct_hub_model_reference_arn_from_inputs(hub_arn, model_name, version) + == hub_content_arn_prefix + "/ModelReference/pytorch-ic-imagenet-v2/*" + ) + + +def test_generate_hub_arn_for_init_kwargs(): + hub_name = "my-hub-name" + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock default session with default values + mock_default_session = Mock() + mock_default_session.account_id.return_value = "123456789123" + mock_default_session.boto_region_name = JUMPSTART_DEFAULT_REGION_NAME + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, session=mock_default_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, "us-east-1", session=mock_default_session) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, "eu-west-1", mock_custom_session) + == "arn:aws:sagemaker:eu-west-1:000000000000:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, None, mock_custom_session) + == "arn:aws:sagemaker:us-east-2:000000000000:hub/my-hub-name" + ) + + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, session=mock_default_session) == hub_arn + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", session=mock_default_session) + == hub_arn + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn + ) + + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn + + +def test_is_gated_bucket(): + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True + + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-east-1") is True + + assert utils.is_gated_bucket("jumpstart-cache-prod-us-west-2") is False + + assert utils.is_gated_bucket("") is False + + +@patch("sagemaker.session.Session") +def test_get_hub_model_version_success(mock_session): + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = "1.0.0" + mock_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, + ] + } + + result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == "1.0.0" + + +@patch("sagemaker.session.Session") +def test_get_hub_model_version_None(mock_session): + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = None + mock_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, + ] + } + + result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == "2.0.0" + + +@patch("sagemaker.session.Session") 
+def test_get_hub_model_version_wildcard_char(mock_session): + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = "*" + mock_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, + ] + } + + result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == "2.0.0" + + +def test_walk_and_apply_json(): + test_json = { + "CamelCaseKey": "value", + "CamelCaseObjectKey": { + "CamelCaseObjectChildOne": "value1", + "CamelCaseObjectChildTwo": "value2", + }, + "IgnoreMyChildren": {"ShouldNotBeTouchedOne": "const1", "ShouldNotBeTouchedTwo": "const2"}, + "ShouldNotIgnoreMyChildren": {"NopeNope": "no"}, + } + + result = parser_utils.walk_and_apply_json( + test_json, parser_utils.camel_to_snake, ["ignore_my_children"] + ) + assert result == { + "camel_case_key": "value", + "camel_case_object_key": { + "camel_case_object_child_one": "value1", + "camel_case_object_child_two": "value2", + }, + "ignore_my_children": { + "ShouldNotBeTouchedOne": "const1", + "ShouldNotBeTouchedTwo": "const2", + }, + "should_not_ignore_my_children": {"nope_nope": "no"}, + } + + +def test_walk_and_apply_json_no_stop(): + test_json = { + "CamelCaseKey": "value", + "CamelCaseObjectKey": { + "CamelCaseObjectChildOne": "value1", + "CamelCaseObjectChildTwo": "value2", + }, + "CamelCaseObjectListKey": {"instance.ml.type.xlarge": [{"ShouldChangeMe": "string"}]}, + } + + result = parser_utils.walk_and_apply_json(test_json, parser_utils.camel_to_snake) + assert result == { + "camel_case_key": "value", + "camel_case_object_key": { + "camel_case_object_child_one": "value1", + "camel_case_object_child_two": "value2", + }, + "camel_case_object_list_key": {"instance.ml.type.xlarge": [{"should_change_me": "string"}]}, + } diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 8b00eb5bcd..d9b126f651 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -15,8 +15,11 @@ from typing import Optional, Set from unittest import mock import unittest + +import pandas as pd from mock import MagicMock, Mock import pytest +from sagemaker_core.shapes import ModelAccessConfig from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.jumpstart.artifacts.environment_variables import ( _retrieve_default_environment_variables, @@ -40,12 +43,19 @@ from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit.sagemaker.jumpstart.utils import ( + get_prototype_spec_with_configs, get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, get_special_model_spec_for_inference_component_based_endpoint, get_prototype_manifest, get_prototype_model_spec, + get_base_spec_with_prototype_configs, + get_mock_init_kwargs, + get_base_deployment_configs, + get_base_spec_with_prototype_configs_with_missing_benchmarks, + append_instance_stat_metrics, + append_gated_draft_model_specs_to_jumpstart_model_spec, ) import boto3 @@ -60,13 +70,17 @@ class ModelTest(unittest.TestCase): - mock_session_empty_config = MagicMock(sagemaker_config={}) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) 
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -80,6 +94,7 @@ def test_non_prepacked( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_jumpstart_model_factory_logger: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -139,9 +154,14 @@ def test_non_prepacked( endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -154,6 +174,7 @@ def test_non_prepacked_inference_component_based_endpoint( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -219,9 +240,14 @@ def test_non_prepacked_inference_component_based_endpoint( endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -234,6 +260,7 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -253,8 +280,15 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom mock_model_init.assert_called_once_with( image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:" "1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", - model_data="s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/" - "v1.0.0/infer-prepack-huggingface-txt2img-conflictx-complex-lineart.tar.gz", + model_data={ + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-prod-us-west-2/" + "huggingface-txt2img/huggingface-txt2img-conflictx" + "-complex-lineart/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, env={ "SAGEMAKER_PROGRAM": 
"inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -265,7 +299,8 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom predictor_cls=Predictor, role=execution_role, sagemaker_session=sagemaker_session, - enable_network_isolation=False, + enable_network_isolation=True, + name="hf-txt2img-conflictx-complex-lineart-7777", ) custom_resource_requirements = ResourceRequirements( @@ -287,15 +322,22 @@ def test_non_prepacked_inference_component_based_endpoint_no_default_pass_custom wait=True, tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "js-model-class-model-prepacked"}, - {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.1.0"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "2.0.3"}, ], endpoint_logging=False, resources=custom_resource_requirements, endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, + endpoint_name="hf-txt2img-conflictx-complex-lineart-7777", ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -307,7 +349,10 @@ def test_prepacked( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, ): + mock_sagemaker_timestamp.return_value = "8675309" mock_model_deploy.return_value = default_predictor mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -325,8 +370,15 @@ def test_prepacked( mock_model_init.assert_called_once_with( image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:" "1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04", - model_data="s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/" - "v1.0.0/infer-prepack-huggingface-txt2img-conflictx-complex-lineart.tar.gz", + model_data={ + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-prod-us-west-2/" + "huggingface-txt2img/huggingface-txt2img-conflictx" + "-complex-lineart/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, env={ "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -337,7 +389,8 @@ def test_prepacked( predictor_cls=Predictor, role=execution_role, sagemaker_session=sagemaker_session, - enable_network_isolation=False, + enable_network_isolation=True, + name="hf-txt2img-conflictx-complex-lineart-8675309", ) model.deploy() @@ -345,20 +398,26 @@ def test_prepacked( mock_model_deploy.assert_called_once_with( initial_instance_count=1, instance_type="ml.p3.2xlarge", + endpoint_name="hf-txt2img-conflictx-complex-lineart-8675309", wait=True, tags=[ {"Key": JumpStartTag.MODEL_ID, "Value": "js-model-class-model-prepacked"}, - {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.1.0"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "2.0.3"}, ], endpoint_logging=False, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.model.LOGGER.warning") @mock.patch("sagemaker.utils.sagemaker_timestamp") 
@mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_no_compiled_model_warning_log_js_models( @@ -370,6 +429,7 @@ def test_no_compiled_model_warning_log_js_models( mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, mock_warning: mock.Mock(), + mock_get_jumpstart_configs: mock.Mock, ): mock_timestamp.return_value = "1234" @@ -390,11 +450,16 @@ def test_no_compiled_model_warning_log_js_models( mock_warning.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_eula_gated_conditional_s3_prefix_metadata_model( @@ -405,6 +470,7 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( mock_create_model: mock.Mock, mock_endpoint_from_production_variants: mock.Mock, mock_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_timestamp.return_value = "1234" @@ -452,16 +518,21 @@ def test_eula_gated_conditional_s3_prefix_metadata_model( ], ) + @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.model.Model.register") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_proprietary_model_endpoint( self, + mock_model_register: mock.Mock, mock_model_deploy: mock.Mock, mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, @@ -469,7 +540,9 @@ def test_proprietary_model_endpoint( mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): + mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {} mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) @@ -493,8 +566,17 @@ def test_proprietary_model_endpoint( enable_network_isolation=False, ) + model.register() model.deploy() + 
mock_model_register.assert_called_once_with( + model_type=JumpStartModelType.PROPRIETARY, + content_types=["application/json"], + response_types=["application/json"], + model_package_group_name=model_id, + source_uri=model.model_package_arn, + ) + mock_model_deploy.assert_called_once_with( initial_instance_count=1, instance_type="ml.p4de.24xlarge", @@ -509,6 +591,7 @@ def test_proprietary_model_endpoint( container_startup_health_check_timeout=600, ) + @mock.patch("sagemaker.jumpstart.model.get_jumpstart_configs") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -520,7 +603,9 @@ def test_deprecated( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): + mock_get_jumpstart_configs.side_effect = lambda *args, **kwargs: {} mock_model_deploy.return_value = default_predictor mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -536,6 +621,9 @@ def test_deprecated( JumpStartModel(model_id=model_id, tolerate_deprecated_model=True).deploy() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @@ -547,6 +635,7 @@ def test_vulnerable( mock_model_init: mock.Mock, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -610,9 +699,14 @@ def test_model_use_kwargs(self): deploy_kwargs=all_deploy_kwargs_used, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.factory.model.environment_variables.retrieve_default") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -625,6 +719,7 @@ def evaluate_model_workflow_with_kwargs( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_retrieve_environment_variables: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, init_kwargs: Optional[dict] = None, deploy_kwargs: Optional[dict] = None, ): @@ -698,8 +793,9 @@ def test_jumpstart_model_kwargs_match_parent_class(self): Please add the new argument to the skip set below, and reach out to JumpStart team.""" - init_args_to_skip: Set[str] = set([]) - deploy_args_to_skip: Set[str] = set(["kwargs"]) + init_args_to_skip: Set[str] = set(["model_reference_arn"]) + deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn", "update_endpoint"]) + deploy_args_removed_at_deploy_time: Set[str] = set(["model_access_configs"]) parent_class_init = Model.__init__ parent_class_init_args = set(signature(parent_class_init).parameters.keys()) @@ -715,6 +811,8 @@ def test_jumpstart_model_kwargs_match_parent_class(self): 
"tolerate_deprecated_model", "instance_type", "model_package_arn", + "config_name", + "hub_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -724,9 +822,18 @@ def test_jumpstart_model_kwargs_match_parent_class(self): js_class_deploy = JumpStartModel.deploy js_class_deploy_args = set(signature(js_class_deploy).parameters.keys()) - assert js_class_deploy_args - parent_class_deploy_args == set() - assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip + assert ( + js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time + == set() + ) + assert ( + parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time + == deploy_args_to_skip + ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @@ -735,6 +842,7 @@ def test_validate_model_id_and_get_type( mock_validate_model_id_and_get_type: mock.Mock, mock_init: mock.Mock, mock_get_init_kwargs: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS JumpStartModel(model_id="valid_model_id") @@ -743,9 +851,14 @@ def test_validate_model_id_and_get_type( with pytest.raises(ValueError): JumpStartModel(model_id="invalid_model_id") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -758,6 +871,7 @@ def test_no_predictor_returns_default_predictor( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -782,17 +896,24 @@ def test_no_predictor_returns_default_predictor( model_id=model_id, model_version="*", region=region, + hub_arn=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -805,6 +926,7 @@ def test_no_predictor_yes_async_inference_config( mock_session: mock.Mock, 
mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -826,9 +948,14 @@ def test_no_predictor_yes_async_inference_config( mock_get_default_predictor.assert_not_called() + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -841,6 +968,7 @@ def test_yes_predictor_returns_default_predictor( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_get_default_predictor: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_get_default_predictor.return_value = default_predictor_with_presets @@ -862,10 +990,15 @@ def test_yes_predictor_returns_default_predictor( self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.JumpStartModelsAccessor.reset_cache") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -877,6 +1010,7 @@ def test_model_id_not_found_refeshes_cache_inference( mock_retrieve_kwargs: mock.Mock, mock_model_init: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.side_effect = [False, False] @@ -903,6 +1037,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -910,6 +1045,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), ] ) @@ -934,6 +1070,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -941,10 +1078,14 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), ] ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) 
@@ -952,6 +1093,7 @@ def test_jumpstart_model_tags( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -968,7 +1110,7 @@ def test_jumpstart_model_tags( js_tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "env-var-variant-model"}, - {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "1.0.0"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "1.6.2"}, ] self.assertEqual( @@ -981,6 +1123,9 @@ def test_jumpstart_model_tags( [{"Key": "blah", "Value": "blahagain"}] + js_tags, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -988,6 +1133,7 @@ def test_jumpstart_model_tags_disabled( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1015,6 +1161,9 @@ def test_jumpstart_model_tags_disabled( [{"Key": "blah", "Value": "blahagain"}], ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1022,6 +1171,7 @@ def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1042,13 +1192,16 @@ def test_jumpstart_model_package_arn( self.assertEqual( mock_session.create_model.call_args[0][2], { - "ModelPackageName": "arn:aws:sagemaker:us-west-2:594846645681:model-package" - "/llama2-7b-f-e46eb8a833643ed58aaccd81498972c3" + "ModelPackageName": "arn:aws:sagemaker:us-west-2:594846645681:" + "model-package/llama2-7b-f-v4-71eeccf76ddf33f2a18d2e16b9c7f302" }, ) self.assertIn(tag, mock_session.create_model.call_args[1]["tags"]) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1056,6 +1209,7 @@ def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1091,8 +1245,13 @@ def test_jumpstart_model_package_arn_override( }, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + 
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn_unsupported_region( @@ -1100,6 +1259,7 @@ def test_jumpstart_model_package_arn_unsupported_region( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1111,15 +1271,21 @@ def test_jumpstart_model_package_arn_unsupported_region( mock_session.return_value = MagicMock(sagemaker_config={}) with pytest.raises(ValueError) as e: - JumpStartModel(model_id=model_id, region="us-east-2") + JumpStartModel(model_id=model_id, region="us-west-1") assert ( str(e.value) == "Model package arn for 'js-model-package-arn' not supported in " - "us-east-2. Please try one of the following regions: us-west-2, us-east-1." + "us-west-1. Please try one of the following regions: " + "us-west-2, us-east-2, us-east-1, eu-west-1, ap-southeast-1, ap-southeast-2." ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1134,6 +1300,7 @@ def test_model_data_s3_prefix_override( mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1183,8 +1350,14 @@ def test_model_data_s3_prefix_override( '"S3DataType": "S3Prefix", "CompressionType": "None"}}', ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1198,7 +1371,12 @@ def test_model_data_s3_prefix_model( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, ): + + mock_sagemaker_timestamp.return_value = "8675309" + mock_model_deploy.return_value = default_predictor mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1211,24 +1389,35 @@ def test_model_data_s3_prefix_model( JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge") mock_model_init.assert_called_once_with( - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-gpu-py38", - model_data={ - "S3DataSource": { - 
"S3Uri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/", - "S3DataType": "S3Prefix", - "CompressionType": "None", - } + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-" + "pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + model_data="s3://jumpstart-cache-prod-us-west-2/huggingface-infer/" + "prepack/v1.1.2/infer-prepack-huggingface-text2text-flan-t5-xxl-fp16.tar.gz", + env={ + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "TS_DEFAULT_WORKERS_PER_MODEL": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", }, predictor_cls=Predictor, role=execution_role, sagemaker_session=sagemaker_session, - enable_network_isolation=False, + enable_network_isolation=True, + name="hf-text2text-flan-t5-xxl-fp16-8675309", ) mock_js_info_logger.assert_not_called() + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1242,7 +1431,12 @@ def test_model_artifact_variant_model( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, ): + + mock_sagemaker_timestamp.return_value = "8675309" + mock_model_deploy.return_value = default_predictor mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS @@ -1256,8 +1450,7 @@ def test_model_artifact_variant_model( JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge") mock_model_init.assert_called_once_with( - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-" - "inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-gpu-py38", model_data="s3://jumpstart-cache-prod-us-west-2/hello-world-1", env={ "SAGEMAKER_PROGRAM": "inference.py", @@ -1270,15 +1463,24 @@ def test_model_artifact_variant_model( role=execution_role, sagemaker_session=sagemaker_session, enable_network_isolation=True, + name="pt-ic-mobilenet-v2-8675309", ) mock_model_init.reset_mock() - JumpStartModel(model_id=model_id, instance_type="ml.p99.xlarge") + JumpStartModel(model_id=model_id, instance_type="ml.p3.2xlarge") mock_model_init.assert_called_once_with( - image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.5.0-gpu-py3", - model_data="s3://jumpstart-cache-prod-us-west-2/basfsdfssf", + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-inference:1.10.0-gpu-py38", + model_data={ + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-prod-us-west-2/" + "pytorch-ic/pytorch-ic-mobilenet-v2/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, env={ "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -1290,10 +1492,72 @@ def test_model_artifact_variant_model( role=execution_role, sagemaker_session=sagemaker_session, 
enable_network_isolation=True, + name="pt-ic-mobilenet-v2-8675309", ) + @mock.patch("sagemaker.jumpstart.model.get_model_info_from_endpoint") + @mock.patch("sagemaker.jumpstart.model.JumpStartModel.__init__") + def test_attach( + self, + mock_js_model_init, + mock_get_model_info_from_endpoint, + ): + mock_js_model_init.return_value = None + mock_get_model_info_from_endpoint.return_value = ( + "model-id", + "model-version", + None, + None, + None, + ) + val = JumpStartModel.attach("some-endpoint") + mock_get_model_info_from_endpoint.assert_called_once_with( + endpoint_name="some-endpoint", + inference_component_name=None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + mock_js_model_init.assert_called_once_with( + model_id="model-id", + model_version="model-version", + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_name=None, + ) + assert isinstance(val, JumpStartModel) + + mock_get_model_info_from_endpoint.reset_mock() + JumpStartModel.attach("some-endpoint", model_id="some-id") + mock_get_model_info_from_endpoint.assert_called_once_with( + endpoint_name="some-endpoint", + inference_component_name=None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + + mock_get_model_info_from_endpoint.reset_mock() + JumpStartModel.attach("some-endpoint", model_id="some-id", model_version="some-version") + mock_get_model_info_from_endpoint.assert_called_once_with( + endpoint_name="some-endpoint", + inference_component_name=None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) + + # providing model id, version, and ic name should bypass check with endpoint tags + mock_get_model_info_from_endpoint.reset_mock() + JumpStartModel.attach( + "some-endpoint", + model_id="some-id", + model_version="some-version", + inference_component_name="some-ic-name", + ) + + mock_get_model_info_from_endpoint.assert_not_called() + + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.model.Model.register") @@ -1305,6 +1569,7 @@ def test_model_registry_accept_and_response_types( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_validate_model_id_and_get_type: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -1320,10 +1585,15 @@ def test_model_registry_accept_and_response_types( model.register() mock_model_register.assert_called_once_with( - content_types=["application/x-text"], + model_type=JumpStartModelType.OPEN_WEIGHTS, + content_types=["application/x-text", "application/json"], response_types=["application/json;verbose", "application/json"], + model_package_group_name=model.model_id, ) + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch("sagemaker.jumpstart.model.get_default_predictor") @mock.patch("sagemaker.jumpstart.model.Model.__init__") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @@ -1337,6 +1607,7 @@ def test_jumpstart_model_session( mock_deploy, mock_init, get_default_predictor, + mock_get_jumpstart_configs: mock.Mock, ): 
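+        # Mock arguments arrive in reverse decorator order: unittest.mock applies
+        # stacked @mock.patch decorators bottom-up, so the patch nearest the def
+        # feeds the first parameter, and the get_jumpstart_configs patch added at
+        # the top of each stack lands last. A minimal sketch of the rule, with
+        # illustrative names:
+        #
+        #     @mock.patch("mod.b")  # outermost -> last mock argument
+        #     @mock.patch("mod.a")  # innermost -> first mock argument
+        #     def test(self, mock_a, mock_b): ...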
mock_validate_model_id_and_get_type.return_value = True @@ -1370,6 +1641,10 @@ def test_jumpstart_model_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) @mock.patch.dict( "sagemaker.jumpstart.cache.os.environ", { @@ -1378,7 +1653,9 @@ def test_jumpstart_model_session( }, ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") - @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) @@ -1388,31 +1665,668 @@ def test_model_local_mode( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, ): + + mock_sagemaker_timestamp.return_value = "8675309" + mock_get_model_specs.side_effect = get_prototype_model_spec mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) mock_model_deploy.return_value = default_predictor - model_id, _ = "pytorch-eqa-bert-base-cased", "*" + model_id, _ = "pytorch-ic-mobilenet-v2", "*" mock_session.return_value = sagemaker_session - model = JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge") + model = JumpStartModel(model_id=model_id, instance_type="ml.m5.xlarge") model.deploy() mock_model_deploy.assert_called_once_with( initial_instance_count=1, - instance_type="ml.p2.xlarge", + instance_type="ml.m5.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + ], + wait=True, + endpoint_logging=False, + endpoint_name="pt-ic-mobilenet-v2-8675309", + ) + + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_initialization_with_config_name( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + + mock_sagemaker_timestamp.return_value = "8675309" + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, config_name="neuron-inference") + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + 
initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + endpoint_name="pt-ic-mobilenet-v2-8675309", + ) + + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + + mock_sagemaker_timestamp.return_value = "8675309" + mock_get_model_specs.side_effect = get_prototype_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.m5.large", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + ], + wait=True, + endpoint_logging=False, + endpoint_name="pt-ic-mobilenet-v2-8675309", + ) + + mock_get_model_specs.reset_mock() + mock_model_deploy.reset_mock() + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + model.set_deployment_config("neuron-inference", "ml.inf2.2xlarge") + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.2xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + endpoint_name="pt-ic-mobilenet-v2-8675309", + ) + mock_model_deploy.reset_mock() + model.set_deployment_config("neuron-inference", "ml.inf2.xlarge") + + assert model.config_name == "neuron-inference" + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.inf2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + ], + wait=True, + endpoint_logging=False, + endpoint_name="pt-ic-mobilenet-v2-8675309", + ) + + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + 
@mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_and_deploy_for_gated_draft_model( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + # WHERE + + mock_sagemaker_timestamp.return_value = "8675309" + mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id = "pytorch-ic-mobilenet-v2" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + # WHEN + model.deploy( + model_access_configs={"pytorch-ic-mobilenet-v2": ModelAccessConfig(accept_eula=True)} + ) + + # THEN + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.m5.large", tags=[ - {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, - {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, ], wait=True, endpoint_logging=False, + endpoint_name="pt-ic-mobilenet-v2-8675309", + ) + + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_and_deploy_for_gated_draft_model_no_model_access_configs( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + ): + # WHERE + mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id = "pytorch-ic-mobilenet-v2" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + # WHEN / THEN + with self.assertRaises(ValueError): + model.deploy() + + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.model.Model.__init__") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def 
test_model_deployment_config_additional_model_data_source( + self, + mock_model_init: mock.Mock, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + + mock_sagemaker_timestamp.return_value = "8675309" + + mock_session.return_value = sagemaker_session + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + model = JumpStartModel(model_id=model_id, config_name="gpu-accelerated") + + mock_model_init.assert_called_once_with( + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.10.0-cpu-py38", + model_data={ + "S3DataSource": { + "S3Uri": "s3://jumpstart-cache-prod-us-west-2/pytorch-ic" + "/pytorch-ic-mobilenet-v2/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + env={ + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + predictor_cls=Predictor, + role=execution_role, + sagemaker_session=sagemaker_session, + enable_network_isolation=True, + name="pt-ic-mobilenet-v2-8675309", + additional_model_data_sources=[ + { + "ChannelName": "draft_model_name", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://sagemaker-sd-models-prod-us-west-2/key/to/draft/model/artifact/", + "ModelAccessConfig": {"AcceptEula": False}, + }, + "HostingEulaKey": None, + } + ], + ) + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.m5.large", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-accelerated"}, + ], + wait=True, + endpoint_logging=False, + endpoint_name="pt-ic-mobilenet-v2-8675309", + ) + + # TODO: Commenting out this test due to flakiness. 
Need to mock the session + # @mock.patch( + # "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + # ) + # @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + # @mock.patch("sagemaker.jumpstart.factory.model.Session") + # @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + # @mock.patch("sagemaker.jumpstart.model.Model.deploy") + # @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + # def test_model_set_deployment_config_model_package( + # self, + # mock_model_deploy: mock.Mock, + # mock_get_model_specs: mock.Mock, + # mock_session: mock.Mock, + # mock_get_manifest: mock.Mock, + # mock_get_jumpstart_configs: mock.Mock, + # ): + # mock_get_model_specs.side_effect = get_prototype_spec_with_configs + # mock_get_manifest.side_effect = ( + # lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + # ) + # mock_model_deploy.return_value = default_predictor + + # model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + # mock_session.return_value = sagemaker_session + + # model = JumpStartModel(model_id=model_id) + + # assert model.config_name == "neuron-inference" + + # model.deploy() + + # mock_model_deploy.assert_called_once_with( + # initial_instance_count=1, + # instance_type="ml.inf2.xlarge", + # tags=[ + # {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + # {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + # {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "neuron-inference"}, + # ], + # wait=True, + # endpoint_logging=False, + # ) + + # mock_model_deploy.reset_mock() + + # model.set_deployment_config( + # config_name="gpu-inference-model-package", instance_type="ml.p2.xlarge" + # ) + + # assert ( + # model.model_package_arn + # == "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + # ) + # model.deploy() + + # mock_model_deploy.assert_called_once_with( + # initial_instance_count=1, + # instance_type="ml.p2.xlarge", + # tags=[ + # {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + # {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + # {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "gpu-inference-model-package"}, + # ], + # wait=True, + # endpoint_logging=False, + # ) + + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch( + "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {} + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_set_deployment_config_incompatible_instance_type_or_name( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_get_jumpstart_configs: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + mock_sagemaker_timestamp.return_value = "8675309" + + mock_get_model_specs.side_effect = get_prototype_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = 
"pytorch-ic-mobilenet-v2", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + assert model.config_name is None + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.m5.large", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-ic-mobilenet-v2"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "3.0.6"}, + ], + wait=True, + endpoint_logging=False, + endpoint_name="pt-ic-mobilenet-v2-8675309", + ) + + mock_get_model_specs.reset_mock() + mock_model_deploy.reset_mock() + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + + with pytest.raises(ValueError) as error: + model.set_deployment_config("neuron-inference-unknown-name", "ml.inf2.32xlarge") + assert "Cannot find Jumpstart config name neuron-inference-unknown-name. " in str(error) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_list_deployment_configs( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() + ) + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + configs = model.list_deployment_configs() + + self.assertEqual(configs, get_base_deployment_configs(True)) + + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_list_deployment_configs_empty( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + ): + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_special_model_spec(model_id="gemma-model") + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: 
get_prototype_manifest(region, model_type) + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + configs = model.list_deployment_configs() + + self.assertTrue(len(configs) == 0) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_retrieve_deployment_config( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs() + ) + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) + mock_get_model_specs.side_effect = get_prototype_spec_with_configs + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + expected = get_base_deployment_configs()[0] + config_name = expected.get("DeploymentConfigName") + instance_type = expected.get("InstanceType") + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs( + model_id, config_name + ) + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id) + + model.set_deployment_config(config_name, instance_type) + + self.assertEqual(model.deployment_config, expected) + + mock_get_init_kwargs.reset_mock() + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + + @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") + @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_display_benchmark_metrics( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, + mock_verify_model_region_and_return_specs: mock.Mock, + mock_get_init_kwargs: mock.Mock, + ): + model_id, _ = "pytorch-ic-mobilenet-v2", "*" + + mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id) + mock_verify_model_region_and_return_specs.side_effect = ( + lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() + ) + 
+        mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: (
+            None,
+            append_instance_stat_metrics(metrics),
+        )
+        mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+        mock_get_manifest.side_effect = (
+            lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+        )
+
+        mock_session.return_value = sagemaker_session
+
+        model = JumpStartModel(model_id=model_id)
+
+        model.display_benchmark_metrics()
+        model.display_benchmark_metrics(instance_type="g5.12xlarge")
+
+    @mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
+    @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
+    @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics")
+    @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
+    @mock.patch("sagemaker.jumpstart.factory.model.Session")
+    @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+    @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+    def test_model_benchmark_metrics(
+        self,
+        mock_get_model_specs: mock.Mock,
+        mock_session: mock.Mock,
+        mock_get_manifest: mock.Mock,
+        mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock,
+        mock_verify_model_region_and_return_specs: mock.Mock,
+        mock_get_init_kwargs: mock.Mock,
+    ):
+        model_id, _ = "pytorch-ic-mobilenet-v2", "*"
+
+        mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
+        mock_verify_model_region_and_return_specs.side_effect = (
+            lambda *args, **kwargs: get_base_spec_with_prototype_configs()
+        )
+        mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: (
+            None,
+            append_instance_stat_metrics(metrics),
+        )
+        mock_get_model_specs.side_effect = get_prototype_spec_with_configs
+        mock_get_manifest.side_effect = (
+            lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+        )
+
+        mock_session.return_value = sagemaker_session
+
+        model = JumpStartModel(model_id=model_id)
+
+        df = model.benchmark_metrics
+
+        self.assertTrue(isinstance(df, pd.DataFrame))
 
 
 def test_jumpstart_model_requires_model_id():
diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py
index 70409704e6..a0299ebb1a 100644
--- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py
+++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py
@@ -60,6 +60,9 @@ class IntelligentDefaultsModelTest(unittest.TestCase):
     region = "us-west-2"
     sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
 
+    @mock.patch(
+        "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+    )
     @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
     @mock.patch("sagemaker.jumpstart.model.Model.__init__")
     @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -77,6 +80,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config(
         mock_retrieve_kwargs: mock.Mock,
         mock_model_init: mock.Mock,
         mock_validate_model_id_and_get_type: mock.Mock,
+        mock_get_jumpstart_configs: mock.Mock,
     ):
 
         mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -95,12 +99,15 @@ def test_without_arg_overwrites_without_kwarg_collisions_with_config(
             model_id=model_id,
         )
 
-        self.assertEquals(mock_get_sagemaker_config_value.call_count, 1)
+        self.assertEqual(mock_get_sagemaker_config_value.call_count, 1)
 
-        self.assertEquals(mock_model_init.call_args[1].get("role"), config_role)
+        self.assertEqual(mock_model_init.call_args[1].get("role"), config_role)
 
         assert "enable_network_isolation" not in mock_model_init.call_args[1]
 
+    @mock.patch(
+        "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+    )
     @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
     @mock.patch("sagemaker.jumpstart.model.Model.__init__")
     @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -118,6 +125,7 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config(
         mock_retrieve_kwargs: mock.Mock,
         mock_model_init: mock.Mock,
         mock_validate_model_id_and_get_type: mock.Mock,
+        mock_get_jumpstart_configs: mock.Mock,
     ):
 
         mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -139,14 +147,17 @@ def test_all_arg_overwrites_without_kwarg_collisions_with_config(
             role=override_role,
         )
 
-        self.assertEquals(mock_get_sagemaker_config_value.call_count, 1)
+        self.assertEqual(mock_get_sagemaker_config_value.call_count, 1)
 
-        self.assertEquals(mock_model_init.call_args[1].get("role"), override_role)
-        self.assertEquals(
+        self.assertEqual(mock_model_init.call_args[1].get("role"), override_role)
+        self.assertEqual(
             mock_model_init.call_args[1].get("enable_network_isolation"),
             override_enable_network_isolation,
         )
 
+    @mock.patch(
+        "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+    )
     @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
     @mock.patch("sagemaker.jumpstart.model.Model.__init__")
     @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -164,6 +175,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config(
         mock_retrieve_kwargs: mock.Mock,
         mock_model_init: mock.Mock,
         mock_validate_model_id_and_get_type: mock.Mock,
+        mock_get_jumpstart_configs: mock.Mock,
     ):
 
         mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -185,14 +197,17 @@ def test_without_arg_overwrites_all_kwarg_collisions_with_config(
             model_id=model_id,
         )
 
-        self.assertEquals(mock_get_sagemaker_config_value.call_count, 2)
+        self.assertEqual(mock_get_sagemaker_config_value.call_count, 2)
 
-        self.assertEquals(mock_model_init.call_args[1].get("role"), config_role)
-        self.assertEquals(
+        self.assertEqual(mock_model_init.call_args[1].get("role"), config_role)
+        self.assertEqual(
             mock_model_init.call_args[1].get("enable_network_isolation"),
             config_enable_network_isolation,
         )
 
+    @mock.patch(
+        "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+    )
     @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
     @mock.patch("sagemaker.jumpstart.model.Model.__init__")
     @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -210,6 +225,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config(
         mock_retrieve_kwargs: mock.Mock,
         mock_model_init: mock.Mock,
         mock_validate_model_id_and_get_type: mock.Mock,
+        mock_get_jumpstart_configs: mock.Mock,
     ):
 
         mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -233,14 +249,17 @@ def test_with_arg_overwrites_all_kwarg_collisions_with_config(
             enable_network_isolation=override_enable_network_isolation,
         )
 
-        self.assertEquals(mock_get_sagemaker_config_value.call_count, 1)
+        self.assertEqual(mock_get_sagemaker_config_value.call_count, 1)
 
-        self.assertEquals(mock_model_init.call_args[1].get("role"), override_role)
-        self.assertEquals(
+        self.assertEqual(mock_model_init.call_args[1].get("role"), override_role)
+        self.assertEqual(
             mock_model_init.call_args[1].get("enable_network_isolation"),
             override_enable_network_isolation,
         )
 
+    @mock.patch(
+        "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+    )
     @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
     @mock.patch("sagemaker.jumpstart.model.Model.__init__")
     @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -258,6 +277,7 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config(
         mock_retrieve_kwargs: mock.Mock,
         mock_model_init: mock.Mock,
         mock_validate_model_id_and_get_type: mock.Mock,
+        mock_get_jumpstart_configs: mock.Mock,
     ):
 
         mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -279,14 +299,17 @@ def test_without_arg_overwrites_all_kwarg_collisions_without_config(
             model_id=model_id,
         )
 
-        self.assertEquals(mock_get_sagemaker_config_value.call_count, 2)
+        self.assertEqual(mock_get_sagemaker_config_value.call_count, 2)
 
-        self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role)
-        self.assertEquals(
+        self.assertEqual(mock_model_init.call_args[1].get("role"), execution_role)
+        self.assertEqual(
             mock_model_init.call_args[1].get("enable_network_isolation"),
             metadata_enable_network_isolation,
         )
 
+    @mock.patch(
+        "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+    )
     @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
     @mock.patch("sagemaker.jumpstart.model.Model.__init__")
     @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -304,6 +327,7 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config(
         mock_retrieve_kwargs: mock.Mock,
         mock_model_init: mock.Mock,
         mock_validate_model_id_and_get_type: mock.Mock,
+        mock_get_jumpstart_configs: mock.Mock,
     ):
 
         mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -326,14 +350,17 @@ def test_with_arg_overwrites_all_kwarg_collisions_without_config(
             enable_network_isolation=override_enable_network_isolation,
         )
 
-        self.assertEquals(mock_get_sagemaker_config_value.call_count, 1)
+        self.assertEqual(mock_get_sagemaker_config_value.call_count, 1)
 
-        self.assertEquals(mock_model_init.call_args[1].get("role"), override_role)
-        self.assertEquals(
+        self.assertEqual(mock_model_init.call_args[1].get("role"), override_role)
+        self.assertEqual(
             mock_model_init.call_args[1].get("enable_network_isolation"),
             override_enable_network_isolation,
         )
 
+    @mock.patch(
+        "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+    )
     @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
     @mock.patch("sagemaker.jumpstart.model.Model.__init__")
     @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -351,6 +378,7 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config(
         mock_retrieve_kwargs: mock.Mock,
         mock_model_init: mock.Mock,
         mock_validate_model_id_and_get_type: mock.Mock,
+        mock_get_jumpstart_configs: mock.Mock,
     ):
 
         mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -370,11 +398,14 @@ def test_without_arg_overwrites_without_kwarg_collisions_without_config(
             model_id=model_id,
         )
 
-        self.assertEquals(mock_get_sagemaker_config_value.call_count, 1)
+        self.assertEqual(mock_get_sagemaker_config_value.call_count, 1)
 
-        self.assertEquals(mock_model_init.call_args[1].get("role"), execution_role)
+        self.assertEqual(mock_model_init.call_args[1].get("role"), execution_role)
 
         assert "enable_network_isolation" not in mock_model_init.call_args[1]
 
+    @mock.patch(
+        "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
+    )
     @mock.patch("sagemaker.jumpstart.model.validate_model_id_and_get_type")
     @mock.patch("sagemaker.jumpstart.model.Model.__init__")
     @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
@@ -392,6 +423,7 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config(
         mock_retrieve_kwargs: mock.Mock,
         mock_model_init: mock.Mock,
         mock_validate_model_id_and_get_type: mock.Mock,
+        mock_get_jumpstart_configs: mock.Mock,
     ):
 
         mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
@@ -413,10 +445,10 @@ def test_with_arg_overwrites_without_kwarg_collisions_without_config(
             enable_network_isolation=override_enable_network_isolation,
         )
 
-        self.assertEquals(mock_get_sagemaker_config_value.call_count, 1)
+        self.assertEqual(mock_get_sagemaker_config_value.call_count, 1)
 
-        self.assertEquals(mock_model_init.call_args[1].get("role"), override_role)
-        self.assertEquals(
+        self.assertEqual(mock_model_init.call_args[1].get("role"), override_role)
+        self.assertEqual(
             mock_model_init.call_args[1].get("enable_network_isolation"),
             override_enable_network_isolation,
         )
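A brief aside on the pattern repeated through the test_sagemaker_config.py hunks above: every test gains a new @mock.patch for get_jumpstart_configs at the top of its decorator stack, and the matching mock_get_jumpstart_configs parameter is appended at the end of the signature. That ordering is deliberate: unittest.mock applies stacked patch decorators bottom-up, so the bottommost decorator's mock is injected first and a decorator added above the existing stack lands in the last free parameter slot. A minimal sketch of the rule, using stdlib targets rather than the SageMaker modules:

    import unittest
    from unittest import mock


    class DecoratorOrderExample(unittest.TestCase):
        @mock.patch("os.getcwd")       # topmost decorator -> injected last
        @mock.patch("os.path.exists")  # bottommost decorator -> injected first
        def test_order(self, mock_exists, mock_getcwd):
            mock_exists.return_value = True
            mock_getcwd.return_value = "/tmp"
            self.assertTrue(mock_exists("anything"))
            self.assertEqual(mock_getcwd(), "/tmp")

The same hunks also swap the assertEquals alias for assertEqual; the two behave identically, but the alias has been deprecated since Python 3.2 and emits a DeprecationWarning.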
diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py
index 3d9b5cef6a..75aa93a920 100644
--- a/tests/unit/sagemaker/jumpstart/test_artifacts.py
+++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py
@@ -34,7 +34,7 @@
 from sagemaker.jumpstart.artifacts.model_uris import _retrieve_model_uri
 from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
 
-from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
+from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec
 
 from tests.unit.sagemaker.workflow.conftest import mock_client
 
@@ -176,7 +176,7 @@ def test_retrieve_training_artifact_key(self):
                     "image_uri": "$alias_ecr_uri_1",
                 },
                 "properties": {
-                    "artifact_key": "in/the/way",
+                    "training_artifact_key": "in/the/way",
                 },
             }
         },
@@ -220,12 +220,12 @@
 @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
 class RetrieveKwargsTest(unittest.TestCase):
 
-    model_id, model_version = "pytorch-eqa-bert-base-cased", "*"
+    model_id, model_version = "variant-model", "*"
     region = "us-west-2"
 
     def test_model_kwargs(self, patched_get_model_specs):
 
-        patched_get_model_specs.side_effect = get_spec_from_base_spec
+        patched_get_model_specs.side_effect = get_special_model_spec
 
         kwargs = artifacts._retrieve_model_init_kwargs(
             region=self.region,
@@ -242,7 +242,7 @@ def test_estimator_kwargs(self, patched_volume_size_supported, patched_get_model_specs):
 
         patched_volume_size_supported.return_value = False
 
-        patched_get_model_specs.side_effect = get_spec_from_base_spec
+        patched_get_model_specs.side_effect = get_special_model_spec
 
         kwargs = artifacts._retrieve_estimator_init_kwargs(
             region=self.region,
@@ -262,7 +262,7 @@ def test_estimator_kwargs_with_volume_size(
     ):
         patched_volume_size_supported.return_value = True
 
-        patched_get_model_specs.side_effect = get_spec_from_base_spec
+        patched_get_model_specs.side_effect = get_special_model_spec
 
         kwargs = artifacts._retrieve_estimator_init_kwargs(
             region=self.region,
@@ -282,7 +282,7 @@ def test_model_deploy_kwargs(self, patched_volume_size_supported, patched_get_mo
 
         patched_volume_size_supported.return_value = False
 
-        patched_get_model_specs.side_effect = get_spec_from_base_spec
+        patched_get_model_specs.side_effect = get_special_model_spec
 
         kwargs = artifacts._retrieve_model_deploy_kwargs(
             region=self.region,
@@ -300,7 +300,7 @@ def test_model_deploy_kwargs_with_volume_size(
 
         patched_volume_size_supported.return_value = True
 
-        patched_get_model_specs.side_effect = get_spec_from_base_spec
+        patched_get_model_specs.side_effect = get_special_model_spec
 
         kwargs = artifacts._retrieve_model_deploy_kwargs(
             region=self.region,
@@ -316,7 +316,7 @@ def test_model_deploy_kwargs_with_volume_size(
 
     def test_estimator_fit_kwargs(self, patched_get_model_specs):
 
-        patched_get_model_specs.side_effect = get_spec_from_base_spec
+        patched_get_model_specs.side_effect = get_special_model_spec
 
         kwargs = artifacts._retrieve_estimator_fit_kwargs(
             region=self.region,
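The fixture swap above (get_spec_from_base_spec to get_special_model_spec) leans on the fact that a callable assigned to a mock's side_effect is invoked with the mock's arguments, and its return value is handed back to the caller. A small self-contained illustration; fake_specs and its fields are hypothetical stand-ins, not part of the SDK:

    from unittest import mock


    def fake_specs(region=None, model_id=None, **kwargs):
        # Stand-in for a spec fixture; returns whatever the test needs.
        return {"model_id": model_id, "region": region}


    get_model_specs = mock.Mock(side_effect=fake_specs)
    assert get_model_specs(region="us-west-2", model_id="variant-model") == {
        "model_id": "variant-model",
        "region": "us-west-2",
    }

Because the patched accessor delegates to the fixture on every call, each test can route all spec lookups to a purpose-built spec without touching S3.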
model_id="meta-textgeneration-llama-2-7b", + semantic_version_str="*", + ) + + assert JumpStartModelHeader( + { + "model_id": "meta-textgeneration-llama-2-7b", + "version": "4.13.0", + "min_version": "2.49.0", + "spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json", + } + ) == cache.get_header( + model_id="meta-textgeneration-llama-2-7b", + semantic_version_str="4.*", + ) + assert JumpStartModelHeader( { "model_id": "ai21-summarization", @@ -205,8 +253,11 @@ def test_jumpstart_cache_get_header(): ) assert ( "Unable to find model manifest for 'pytorch-ic-imagenet-inception-v3-classification-4' with " - "version '3.*'. Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for updated list of models. Consider using model ID 'pytorch-ic-imagenet-inception-v3-" + "version '3.*'. Specify a different model ID or try a different AWS Region. " + "For a list of available models, see " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. " + "Consider using model ID " + "'pytorch-ic-imagenet-inception-v3-" "classification-4' with version '2.0.0'." ) in str(e.value) @@ -214,8 +265,9 @@ def test_jumpstart_cache_get_header(): cache.get_header(model_id="pytorch-ic-", semantic_version_str="*") assert ( "Unable to find model manifest for 'pytorch-ic-' with version '*'. " - "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for updated list of models. " + "Specify a different model ID or try a different AWS Region. " + "For a list of available models, see " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. " "Did you mean to use model ID 'pytorch-ic-imagenet-inception-v3-classification-4'?" ) in str(e.value) @@ -223,8 +275,9 @@ def test_jumpstart_cache_get_header(): cache.get_header(model_id="tensorflow-ic-", semantic_version_str="*") assert ( "Unable to find model manifest for 'tensorflow-ic-' with version '*'. " - "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for updated list of models. " + "Specify a different model ID or try a different AWS Region. For a list " + "of available models, see " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. " "Did you mean to use model ID 'tensorflow-ic-imagenet-inception-" "v3-classification-4'?" ) in str(e.value) @@ -237,8 +290,9 @@ def test_jumpstart_cache_get_header(): ) assert ( "Unable to find model manifest for 'ai21-summarize' with version '1.1.003'. " - "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " - "for updated list of models. " + "Specify a different model ID or try a different AWS Region. " + "For a list of available models, see " + "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html. " "Did you mean to use model ID 'ai21-summarization'?" 
) in str(e.value) @@ -495,8 +549,8 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache.get_manifest_file_s3_key() == manifest_file_key assert cache.get_region() == region assert cache.get_bucket() == bucket - assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert cache._content_cache._max_cache_items == max_s3_cache_items + assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._open_weight_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items @@ -535,8 +589,8 @@ def test_jumpstart_proprietary_cache_accepts_input_parameters(): ) assert cache.get_region() == region assert cache.get_bucket() == bucket - assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert cache._content_cache._max_cache_items == max_s3_cache_items + assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._proprietary_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items @@ -1113,3 +1167,199 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( ), ] ) + + +@patch.object(JumpStartModelsCache, "_retrieval_function") +def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights( + retrieval_function: Mock, +): + sm_version = Version(utils.get_sagemaker_version()) + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") + print(str(new_sm_version)) + versions = ["1.0.0", "2.9.1", "2.16.0"] + manifest = [ + { + "model_id": "test-model", + "version": version, + "min_version": "2.49.0", + "spec_key": "spec_key", + } + for version in versions + ] + + manifest.append( + { + "model_id": "test-model", + "version": "3.0.0", + "min_version": str(new_sm_version), + "spec_key": "spec_key", + } + ) + + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( + header_obj + ) + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) + key = JumpStartVersionedModelId("test-model", "*") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None) + + assert_key = JumpStartVersionedModelId("test-model", "2.16.0") + + assert result == assert_key + + +@patch.object(JumpStartModelsCache, "_retrieval_function") +def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights( + retrieval_function: Mock, +): + sm_version = Version(utils.get_sagemaker_version()) + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") + print(str(new_sm_version)) + versions = ["1.0.0", "2.9.1", "2.16.0"] + manifest = [ + { + "model_id": "test-model", + "version": version, + "min_version": "2.49.0", + "spec_key": "spec_key", + } + for version in versions + ] + + manifest.append( + { + "model_id": "test-model", + "version": "3.0.0", + "min_version": str(new_sm_version), + "spec_key": "spec_key", + } + ) + + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( + header_obj + ) + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) + key = 
JumpStartVersionedModelId("test-model", "*") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + result = cache._get_proprietary_manifest_key_from_model_id(key=key, value=None) + + assert_key = JumpStartVersionedModelId("test-model", "2.16.0") + + assert result == assert_key + + +@patch.object(JumpStartModelsCache, "_retrieval_function") +def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_function: Mock): + sm_version = Version(utils.get_sagemaker_version()) + new_sm_version = Version(str(sm_version.major + 1) + ".0.0") + print(str(new_sm_version)) + versions = ["abc", "2.9.1", "2.16.0"] + manifest = [ + { + "model_id": "test-model", + "version": version, + "min_version": "2.49.0", + "spec_key": "spec_key", + } + for version in versions + ] + + manifest_dict = {} + for header in manifest: + header_obj = JumpStartModelHeader(header) + manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = ( + header_obj + ) + retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict) + key = JumpStartVersionedModelId("test-model", "*") + + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None) + + assert_key = JumpStartVersionedModelId("test-model", "abc") + + assert result == assert_key + + +@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region") +@patch( + "sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket" +) +def test_get_json_file_from_s3(): + """Test _get_json_file retrieves from S3 in normal mode.""" + cache = JumpStartModelsCache() + test_key = "test/file/path.json" + test_json_data = {"key": "value"} + test_etag = "test-etag-123" + + with patch.object( + JumpStartModelsCache, + "_get_json_file_and_etag_from_s3", + return_value=(test_json_data, test_etag), + ) as mock_s3_get: + result, etag = cache._get_json_file(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST) + + mock_s3_get.assert_called_once_with(test_key) + assert result == test_json_data + assert etag == test_etag + + +@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region") +@patch( + "sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket" +) +def test_get_json_file_from_local_supported_type(): + """Test _get_json_file retrieves from local override for supported file types.""" + cache = JumpStartModelsCache() + test_key = "test/file/path.json" + test_json_data = {"key": "value"} + + with ( + patch.object(JumpStartModelsCache, "_is_local_metadata_mode", return_value=True), + patch.object( + JumpStartModelsCache, "_get_json_file_from_local_override", return_value=test_json_data + ) as mock_local_get, + ): + result, etag = cache._get_json_file(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST) + + mock_local_get.assert_called_once_with(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST) + assert result == test_json_data + assert etag is None + + +@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region") +@patch( + "sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket" +) +def test_get_json_file_local_mode_unsupported_type(): + """Test _get_json_file falls back to S3 for unsupported file types in local mode.""" + cache = JumpStartModelsCache() + test_key = "test/file/path.json" + test_json_data = {"key": "value"} + 
test_etag = "test-etag-123" + + with ( + patch.object(JumpStartModelsCache, "_is_local_metadata_mode", return_value=True), + patch.object( + JumpStartModelsCache, + "_get_json_file_and_etag_from_s3", + return_value=(test_json_data, test_etag), + ) as mock_s3_get, + patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") as mock_warning, + ): + result, etag = cache._get_json_file(test_key, JumpStartS3FileType.PROPRIETARY_MANIFEST) + + mock_s3_get.assert_called_once_with(test_key) + mock_warning.assert_called_once() + assert "not supported for local override" in mock_warning.call_args[0][0] + assert result == test_json_data + assert etag == test_etag diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index c00d271ef1..b1932b4ffc 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -1,9 +1,12 @@ from __future__ import absolute_import + +import datetime import json from unittest import TestCase -from unittest.mock import Mock, patch -import datetime +from unittest.mock import Mock, patch, ANY + +import boto3 import pytest from sagemaker.jumpstart.constants import ( @@ -17,7 +20,6 @@ get_prototype_manifest, get_prototype_model_spec, ) -from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, @@ -46,7 +48,7 @@ def test_list_jumpstart_scripts( ) patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( - get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() + get_prototype_model_spec(None, "pytorch-ic-mobilenet-v2").to_json() ) assert list_jumpstart_scripts() == sorted(["inference", "training"]) @@ -107,7 +109,6 @@ def test_list_jumpstart_tasks( assert list_jumpstart_tasks() == sorted( [ "classification", - "eqa", "ic", "semseg", "spc", @@ -181,7 +182,6 @@ def test_list_jumpstart_frameworks( "huggingface", "lightgbm", "mxnet", - "pytorch", "sklearn", "xgboost", ] @@ -218,7 +218,7 @@ def test_list_jumpstart_models_simple_case( ("huggingface-spc-bert-base-cased", "1.0.0"), ("lightgbm-classification-model", "1.0.0"), ("mxnet-semseg-fcn-resnet50-ade", "1.0.0"), - ("pytorch-eqa-bert-base-cased", "1.0.0"), + ("pytorch-ic-mobilenet-v2", "1.0.0"), ("sklearn-classification-linear", "1.0.0"), ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"), ("xgboost-classification-model", "1.0.0"), @@ -227,62 +227,61 @@ def test_list_jumpstart_models_simple_case( patched_get_manifest.assert_called() patched_get_model_specs.assert_not_called() - @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), - reason="Contact JumpStart team to fix flaky test.", - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") + @pytest.mark.flaky(reruns=5, reruns_delay=1) def test_list_jumpstart_models_script_filter( self, patched_read_s3_file: Mock, patched_get_manifest: Mock ): patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps( - get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json() + get_prototype_model_spec(None, "pytorch-ic-mobilenet-v2").to_json() ) patched_get_manifest.side_effect = ( - lambda 
diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py
index c00d271ef1..b1932b4ffc 100644
--- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py
+++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py
@@ -1,9 +1,12 @@
 from __future__ import absolute_import
+
+import datetime
 import json
 from unittest import TestCase
-from unittest.mock import Mock, patch
-import datetime
+from unittest.mock import Mock, patch, ANY
+
+import boto3
 import pytest
 
 from sagemaker.jumpstart.constants import (
@@ -17,7 +20,6 @@
     get_prototype_manifest,
     get_prototype_model_spec,
 )
-from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST
 from sagemaker.jumpstart.enums import JumpStartModelType
 from sagemaker.jumpstart.notebook_utils import (
     _generate_jumpstart_model_versions,
@@ -46,7 +48,7 @@ def test_list_jumpstart_scripts(
     ):
         patched_generate_jumpstart_models.side_effect = _generate_jumpstart_model_versions
         patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
-            get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
+            get_prototype_model_spec(None, "pytorch-ic-mobilenet-v2").to_json()
         )
 
         assert list_jumpstart_scripts() == sorted(["inference", "training"])
@@ -107,7 +109,6 @@ def test_list_jumpstart_tasks(
         assert list_jumpstart_tasks() == sorted(
             [
                 "classification",
-                "eqa",
                 "ic",
                 "semseg",
                 "spc",
@@ -181,7 +182,6 @@ def test_list_jumpstart_frameworks(
                 "huggingface",
                 "lightgbm",
                 "mxnet",
-                "pytorch",
                 "sklearn",
                 "xgboost",
             ]
@@ -218,7 +218,7 @@ def test_list_jumpstart_models_simple_case(
             ("huggingface-spc-bert-base-cased", "1.0.0"),
             ("lightgbm-classification-model", "1.0.0"),
             ("mxnet-semseg-fcn-resnet50-ade", "1.0.0"),
-            ("pytorch-eqa-bert-base-cased", "1.0.0"),
+            ("pytorch-ic-mobilenet-v2", "1.0.0"),
             ("sklearn-classification-linear", "1.0.0"),
             ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"),
             ("xgboost-classification-model", "1.0.0"),
@@ -227,62 +227,61 @@ def test_list_jumpstart_models_simple_case(
         patched_get_manifest.assert_called()
         patched_get_model_specs.assert_not_called()
 
-    @pytest.mark.skipif(
-        datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
-        reason="Contact JumpStart team to fix flaky test.",
-    )
     @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
     @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
+    @pytest.mark.flaky(reruns=5, reruns_delay=1)
     def test_list_jumpstart_models_script_filter(
         self, patched_read_s3_file: Mock, patched_get_manifest: Mock
     ):
         patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
-            get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
+            get_prototype_model_spec(None, "pytorch-ic-mobilenet-v2").to_json()
         )
         patched_get_manifest.side_effect = (
-            lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
+            lambda region, model_type, *args, **kwargs: get_prototype_manifest(region)
         )
 
         manifest_length = len(get_prototype_manifest())
         vals = [True, False]
         for val in vals:
-            kwargs = {"filter": f"training_supported == {val}"}
+            kwargs = {"filter": And(f"training_supported == {val}", "model_type is open_weights")}
             list_jumpstart_models(**kwargs)
-            assert patched_read_s3_file.call_count == manifest_length
-            patched_get_manifest.assert_called_once()
+            assert patched_read_s3_file.call_count == 2 * manifest_length
+            assert patched_get_manifest.call_count == 2
 
             patched_get_manifest.reset_mock()
             patched_read_s3_file.reset_mock()
 
-            kwargs = {"filter": f"training_supported != {val}"}
+            kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")}
             list_jumpstart_models(**kwargs)
-            assert patched_read_s3_file.call_count == manifest_length
+            assert patched_read_s3_file.call_count == 2 * manifest_length
             assert patched_get_manifest.call_count == 2
 
             patched_get_manifest.reset_mock()
             patched_read_s3_file.reset_mock()
-
-        kwargs = {"filter": f"training_supported in {vals}", "list_versions": True}
+        kwargs = {
+            "filter": And(f"training_supported != {val}", "model_type is open_weights"),
+            "list_versions": True,
+        }
         assert list_jumpstart_models(**kwargs) == [
             ("catboost-classification-model", "1.0.0"),
             ("huggingface-spc-bert-base-cased", "1.0.0"),
             ("lightgbm-classification-model", "1.0.0"),
             ("mxnet-semseg-fcn-resnet50-ade", "1.0.0"),
-            ("pytorch-eqa-bert-base-cased", "1.0.0"),
+            ("pytorch-ic-mobilenet-v2", "1.0.0"),
             ("sklearn-classification-linear", "1.0.0"),
             ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"),
             ("xgboost-classification-model", "1.0.0"),
         ]
-        assert patched_read_s3_file.call_count == manifest_length
+        assert patched_read_s3_file.call_count == 2 * manifest_length
         assert patched_get_manifest.call_count == 2
 
         patched_get_manifest.reset_mock()
         patched_read_s3_file.reset_mock()
 
-        kwargs = {"filter": f"training_supported not in {vals}"}
+        kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")}
         models = list_jumpstart_models(**kwargs)
         assert [] == models
-        assert patched_read_s3_file.call_count == manifest_length
+        assert patched_read_s3_file.call_count == 2 * manifest_length
         assert patched_get_manifest.call_count == 2
 
     @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@@ -325,7 +324,7 @@ def test_list_jumpstart_models_task_filter(
             ("huggingface-spc-bert-base-cased", "1.0.0"),
             ("lightgbm-classification-model", "1.0.0"),
             ("mxnet-semseg-fcn-resnet50-ade", "1.0.0"),
-            ("pytorch-eqa-bert-base-cased", "1.0.0"),
+            ("pytorch-ic-mobilenet-v2", "1.0.0"),
             ("sklearn-classification-linear", "1.0.0"),
             ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0"),
             ("xgboost-classification-model", "1.0.0"),
@@ -348,7 +347,7 @@ def test_list_jumpstart_models_framework_filter(
         self, patched_read_s3_file: Mock, patched_get_manifest: Mock
     ):
         patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
-            get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
+            get_prototype_model_spec(None, "pytorch-ic-mobilenet-v2").to_json()
         )
         patched_get_manifest.side_effect = lambda region, *args, **kwargs: get_prototype_manifest(
             region
@@ -386,7 +385,7 @@ def test_list_jumpstart_models_framework_filter(
             ("huggingface-spc-bert-base-cased", "1.0.0"),
             ("lightgbm-classification-model", "1.0.0"),
             ("mxnet-semseg-fcn-resnet50-ade", "1.0.0"),
-            ("pytorch-eqa-bert-base-cased", "1.0.0"),
+            ("pytorch-ic-mobilenet-v2", "1.0.0"),
             ("sklearn-classification-linear", "1.0.0"),
             ("xgboost-classification-model", "1.0.0"),
         ]
@@ -480,10 +479,10 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
             ("mxnet-semseg-fcn-resnet50-ade", "2.5.1"),
             ("mxnet-semseg-fcn-resnet50-ade", "1.300.0"),
             ("mxnet-semseg-fcn-resnet50-ade", "1.4.0"),
-            ("pytorch-eqa-bert-base-cased", "2.400.0"),
-            ("pytorch-eqa-bert-base-cased", "2.5.1"),
-            ("pytorch-eqa-bert-base-cased", "1.300.0"),
-            ("pytorch-eqa-bert-base-cased", "1.4.0"),
+            ("pytorch-ic-mobilenet-v2", "2.400.0"),
+            ("pytorch-ic-mobilenet-v2", "2.5.1"),
+            ("pytorch-ic-mobilenet-v2", "1.300.0"),
+            ("pytorch-ic-mobilenet-v2", "1.4.0"),
             ("sklearn-classification-linear", "2.400.0"),
             ("sklearn-classification-linear", "2.5.1"),
             ("sklearn-classification-linear", "1.300.0"),
@@ -509,7 +508,7 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
                 ("huggingface-spc-bert-base-cased", "2.400.0"),
                 ("lightgbm-classification-model", "2.400.0"),
                 ("mxnet-semseg-fcn-resnet50-ade", "2.400.0"),
-                ("pytorch-eqa-bert-base-cased", "2.400.0"),
+                ("pytorch-ic-mobilenet-v2", "2.400.0"),
                 ("sklearn-classification-linear", "2.400.0"),
                 ("tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "2.400.0"),
                 ("xgboost-classification-model", "2.400.0"),
@@ -519,7 +518,7 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
         ) == list_jumpstart_models(list_versions=True)
 
     @pytest.mark.skipif(
-        datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
+        datetime.datetime.now() < datetime.datetime(year=2024, month=8, day=1),
         reason="Contact JumpStart team to fix flaky test.",
     )
     @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@@ -535,24 +534,27 @@ def test_list_jumpstart_models_vulnerable_models(
         )
 
         def vulnerable_inference_model_spec(bucket, key, *args, **kwargs) -> str:
-            spec = get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased")
+            spec = get_prototype_model_spec(None, "catboost-classification-model")
             spec.inference_vulnerable = True
             return json.dumps(spec.to_json())
 
         def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
-            spec = get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased")
+            spec = get_prototype_model_spec(None, "catboost-classification-model")
             spec.training_vulnerable = True
             return json.dumps(spec.to_json())
 
         patched_read_s3_file.side_effect = vulnerable_inference_model_spec
 
         num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
-        num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
         assert [] == list_jumpstart_models(
-            And("inference_vulnerable is false", "training_vulnerable is false")
+            And(
+                "inference_vulnerable is false",
+                "training_vulnerable is false",
+                "model_type is open_weights",
+            )
         )
 
-        assert patched_read_s3_file.call_count == num_specs + num_prop_specs
+        assert patched_read_s3_file.call_count == num_specs
         assert patched_get_manifest.call_count == 2
 
         patched_get_manifest.reset_mock()
@@ -561,10 +563,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
         patched_read_s3_file.side_effect = vulnerable_training_model_spec
 
         assert [] == list_jumpstart_models(
-            And("inference_vulnerable is false", "training_vulnerable is false")
+            And(
+                "inference_vulnerable is false",
+                "training_vulnerable is false",
+                "model_type is open_weights",
+            )
         )
 
-        assert patched_read_s3_file.call_count == num_specs + num_prop_specs
+        assert patched_read_s3_file.call_count == num_specs
         assert patched_get_manifest.call_count == 2
 
         patched_get_manifest.reset_mock()
@@ -574,12 +580,9 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
 
         assert patched_read_s3_file.call_count == 0
 
-    @pytest.mark.skipif(
-        datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
-        reason="Contact JumpStart team to fix flaky test.",
-    )
     @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
     @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
+    @pytest.mark.flaky(reruns=5, reruns_delay=1)
     def test_list_jumpstart_models_deprecated_models(
         self,
         patched_read_s3_file: Mock,
@@ -591,17 +594,18 @@ def test_list_jumpstart_models_deprecated_models(
         )
 
         def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
-            spec = get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased")
+            spec = get_prototype_model_spec(None, "pytorch-ic-mobilenet-v2")
             spec.deprecated = True
             return json.dumps(spec.to_json())
 
         patched_read_s3_file.side_effect = deprecated_model_spec
 
         num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
-        num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
-        assert [] == list_jumpstart_models("deprecated equals false")
+        assert [] == list_jumpstart_models(
+            And("deprecated equals false", "model_type is open_weights")
+        )
 
-        assert patched_read_s3_file.call_count == num_specs + num_prop_specs
+        assert patched_read_s3_file.call_count == num_specs
         assert patched_get_manifest.call_count == 2
 
         patched_get_manifest.reset_mock()
@@ -628,7 +632,7 @@ def test_list_jumpstart_models_no_versions(
             "huggingface-spc-bert-base-cased",
             "lightgbm-classification-model",
             "mxnet-semseg-fcn-resnet50-ade",
-            "pytorch-eqa-bert-base-cased",
+            "pytorch-ic-mobilenet-v2",
             "sklearn-classification-linear",
             "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
             "xgboost-classification-model",
@@ -662,7 +666,7 @@ def test_list_jumpstart_proprietary_models(
             "huggingface-spc-bert-base-cased",
             "lightgbm-classification-model",
             "mxnet-semseg-fcn-resnet50-ade",
-            "pytorch-eqa-bert-base-cased",
+            "pytorch-ic-mobilenet-v2",
             "sklearn-classification-linear",
             "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1",
             "xgboost-classification-model",
@@ -684,7 +688,7 @@ def test_list_jumpstart_models_complex_queries(
         patched_get_manifest: Mock,
     ):
         patched_read_s3_file.side_effect = lambda *args, **kwargs: json.dumps(
-            get_prototype_model_spec(None, "pytorch-eqa-bert-base-cased").to_json()
+            get_prototype_model_spec(None, "pytorch-ic-mobilenet-v2").to_json()
        )
         patched_get_manifest.side_effect = (
             lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
@@ -701,7 +705,7 @@ def test_list_jumpstart_models_complex_queries(
                 "false",
                 "unknown",
             )
-        ) == ["tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1"]
+        ) == ["pytorch-ic-mobilenet-v2", "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1"]
 
         assert list_jumpstart_models(
             Or(
@@ -710,7 +714,7 @@ def test_list_jumpstart_models_complex_queries(
                 "framework==tensorflow",
                 Identity(
                     And(
-                        And("incremental_training_supported==falSE"),
+                        And("incremental_training_supported==true"),
                         "true",
                         Or("unknown", "version equals 1.0.0"),
                     )
@@ -753,16 +757,19 @@ def test_get_model_url(
         patched_get_manifest.side_effect = (
             lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
         )
+        mock_client = boto3.client("s3")
+        region = "us-west-2"
+        mock_session = Mock(s3_client=mock_client, boto_region_name=region)
"1.0.0" - assert "https://xgboost.readthedocs.io/en/latest/" == get_model_url(model_id, version) + model_id, version = "xgboost-classification-model", "*" + assert "https://xgboost.readthedocs.io/en/release_1.7.0/" == get_model_url(model_id, version) model_id, version = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0" assert "https://tfhub.dev/google/bit/m-r101x1/ilsvrc2012_classification/1" == get_model_url( model_id, version ) - model_id, version = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "1.0.0" + model_id, version = "tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1", "*" patched_get_model_specs.reset_mock() patched_get_model_specs.side_effect = lambda *largs, **kwargs: get_prototype_model_spec( @@ -771,12 +778,14 @@ def test_get_model_url( **{key: value for key, value in kwargs.items() if key != "region"}, ) - get_model_url(model_id, version, region="us-west-2") + get_model_url(model_id, version, region="us-west-2", sagemaker_session=mock_session) patched_get_model_specs.assert_called_once_with( model_id=model_id, version=version, region="us-west-2", - s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/jumpstart/test_payload_utils.py b/tests/unit/sagemaker/jumpstart/test_payload_utils.py index afc955e2f3..3c339c9b95 100644 --- a/tests/unit/sagemaker/jumpstart/test_payload_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_payload_utils.py @@ -32,10 +32,36 @@ def test_construct_payload(self, patched_get_model_specs): region = "us-west-2" constructed_payload_body = _construct_payload( - prompt="kobebryant", - model_id=model_id, - model_version="*", - region=region, + prompt="kobebryant", model_id=model_id, model_version="*", region=region + ).body + + self.assertEqual( + { + "hello": {"prompt": "kobebryant"}, + "seed": 43, + }, + constructed_payload_body, + ) + + # Unsupported model + self.assertIsNone( + _construct_payload( + prompt="blah", + model_id="default_payloads", + model_version="*", + region=region, + ) + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + def test_construct_payload_with_specific_alias(self, patched_get_model_specs): + patched_get_model_specs.side_effect = get_special_model_spec + + model_id = "prompt-key" + region = "us-west-2" + + constructed_payload_body = _construct_payload( + prompt="kobebryant", model_id=model_id, model_version="*", region=region, alias="Dog" ).body self.assertEqual( diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 52f28f2da1..8368f72d58 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -18,7 +18,7 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, get_spec_from_base_spec -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support( @@ -52,7 +52,7 @@ def test_jumpstart_predictor_support( assert js_predictor.accept == MIMEType.JSON -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") 
@patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_proprietary_predictor_support( @@ -91,13 +91,13 @@ def test_proprietary_predictor_support( @patch("sagemaker.predictor.Predictor") @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( patched_get_model_specs, patched_verify_model_region_and_return_specs, - patched_get_jumpstart_model_id_version_from_endpoint, + patched_get_model_info_from_endpoint, patched_get_default_predictor, patched_predictor, ): @@ -105,19 +105,19 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec - patched_get_jumpstart_model_id_version_from_endpoint.return_value = ( + patched_get_model_info_from_endpoint.return_value = ( "predictor-specs-model", "1.2.3", None, + None, + None, ) mock_session = Mock() predictor.retrieve_default(endpoint_name="blah", sagemaker_session=mock_session) - patched_get_jumpstart_model_id_version_from_endpoint.assert_called_once_with( - "blah", None, mock_session - ) + patched_get_model_info_from_endpoint.assert_called_once_with("blah", None, mock_session) patched_get_default_predictor.assert_called_once_with( predictor=patched_predictor.return_value, @@ -128,11 +128,13 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_vulnerable_model=False, sagemaker_session=mock_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + config_name=None, + hub_arn=None, ) @patch("sagemaker.predictor.get_default_predictor") -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( @@ -159,7 +161,8 @@ def test_jumpstart_predictor_support_no_model_id_supplied_sad_case( patched_get_default_predictor.assert_not_called() -@patch("sagemaker.predictor.get_model_id_version_from_endpoint") +@patch("sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}) +@patch("sagemaker.predictor.get_model_info_from_endpoint") @patch("sagemaker.jumpstart.payload_utils.JumpStartS3PayloadAccessor.get_object_cached") @patch("sagemaker.jumpstart.model.validate_model_id_and_get_type") @patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") @@ -169,7 +172,8 @@ def test_jumpstart_serializable_payload_with_predictor( patched_verify_model_region_and_return_specs, patched_validate_model_id_and_get_type, patched_get_object_cached, - patched_get_model_id_version_from_endpoint, + patched_get_model_info_from_endpoint, + patched_get_jumpstart_configs, ): patched_get_object_cached.return_value = base64.b64decode("encodedimage") @@ -179,7 +183,7 @@ def test_jumpstart_serializable_payload_with_predictor( patched_get_model_specs.side_effect = get_special_model_spec model_id, model_version = "default_payloads", "*" 
- patched_get_model_id_version_from_endpoint.return_value = model_id, model_version, None + patched_get_model_info_from_endpoint.return_value = model_id, model_version, None js_predictor = predictor.retrieve_default( endpoint_name="blah", model_id=model_id, model_version=model_version diff --git a/tests/unit/sagemaker/jumpstart/test_session_utils.py b/tests/unit/sagemaker/jumpstart/test_session_utils.py index 76ad50f31c..ce06a189bd 100644 --- a/tests/unit/sagemaker/jumpstart/test_session_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_session_utils.py @@ -4,167 +4,202 @@ import pytest from sagemaker.jumpstart.session_utils import ( - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name, - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name, - _get_model_id_version_from_model_based_endpoint, - get_model_id_version_from_endpoint, - get_model_id_version_from_training_job, + _get_model_info_from_inference_component_endpoint_with_inference_component_name, + _get_model_info_from_inference_component_endpoint_without_inference_component_name, + _get_model_info_from_model_based_endpoint, + get_model_info_from_endpoint, + get_model_info_from_training_job, ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_training_job_happy_case( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) - retval = get_model_id_version_from_training_job("bLaH", sagemaker_session=mock_sm_session) + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_training_job_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_training_job_config_name( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + "model_id", + "model_version", + None, + "training_config_name", + ) + + retval = get_model_info_from_training_job("bLaH", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", None, "training_config_name") + + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( + 
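Before the session_utils renames below, note where the predictor tests point their patches: at "sagemaker.predictor.get_model_info_from_endpoint", the module that looks the function up, not the sagemaker.jumpstart.session_utils module that defines it. A patch must target the name in the consuming namespace, or the test keeps seeing the unpatched original. A stdlib-only sketch of the difference; the throwaway consumer module is hypothetical:

    import types
    from unittest import mock

    # A consumer module that binds `exists` at import time,
    # like `from os.path import exists`.
    consumer = types.ModuleType("consumer")
    exec("from os.path import exists\ndef check(p):\n    return exists(p)", consumer.__dict__)

    # Patching the lookup site works:
    with mock.patch.object(consumer, "exists", return_value=True):
        assert consumer.check("/definitely/not/here")

    # Patching the defining module does not affect the already-bound name:
    with mock.patch("os.path.exists", return_value=True):
        assert not consumer.check("/definitely/not/here")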
"arn:aws:sagemaker:us-west-2:123456789012:training-job/bLaH", mock_sm_session + ) + + +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_training_job_no_model_id_inferred( + mock_get_jumpstart_model_info_from_resource_arn, +): + mock_sm_session = Mock() + mock_sm_session.boto_region_name = "us-west-2" + mock_sm_session.account_id = Mock(return_value="123456789012") + + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( None, None, ) with pytest.raises(ValueError): - get_model_id_version_from_training_job("blah", sagemaker_session=mock_sm_session) + get_model_info_from_training_job("blah", sagemaker_session=mock_sm_session) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_model_based_endpoint_happy_case( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) - retval = _get_model_id_version_from_model_based_endpoint( + retval = _get_model_info_from_model_based_endpoint( "bLaH", inference_component_name=None, sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:endpoint/blah", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_inference_component_supplied( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_model_based_endpoint_inference_component_supplied( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_model_based_endpoint_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_model_based_endpoint_no_model_id_inferred( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = 
Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + None, None, None, ) with pytest.raises(ValueError): - _get_model_id_version_from_model_based_endpoint( + _get_model_info_from_model_based_endpoint( "blah", inference_component_name="some-name", sagemaker_session=mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_happy_case( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_happy_case( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( "model_id", "model_version", + None, + None, ) - retval = _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + retval = _get_model_info_from_inference_component_endpoint_with_inference_component_name( "bLaH", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version") + assert retval == ("model_id", "model_version", None, None) - mock_get_jumpstart_model_id_version_from_resource_arn.assert_called_once_with( + mock_get_jumpstart_model_info_from_resource_arn.assert_called_once_with( "arn:aws:sagemaker:us-west-2:123456789012:inference-component/bLaH", mock_sm_session ) -@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_id_version_from_resource_arn") -def test_get_model_id_version_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( - mock_get_jumpstart_model_id_version_from_resource_arn, +@patch("sagemaker.jumpstart.session_utils.get_jumpstart_model_info_from_resource_arn") +def test_get_model_info_from_inference_component_endpoint_with_inference_component_name_no_model_id_inferred( + mock_get_jumpstart_model_info_from_resource_arn, ): mock_sm_session = Mock() mock_sm_session.boto_region_name = "us-west-2" mock_sm_session.account_id = Mock(return_value="123456789012") - mock_get_jumpstart_model_id_version_from_resource_arn.return_value = ( + mock_get_jumpstart_model_info_from_resource_arn.return_value = ( + None, + None, None, None, ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_with_inference_component_name( + _get_model_info_from_inference_component_endpoint_with_inference_component_name( "blah", sagemaker_session=mock_sm_session ) @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_inference_component_endpoint_without_inference_component_name_happy_case( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_inference_component_name_happy_case( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - 
mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -172,10 +207,8 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc return_value=["icname"] ) - retval = ( - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( - "blahblah", mock_sm_session - ) + retval = _get_model_info_from_inference_component_endpoint_without_inference_component_name( + "blahblah", mock_sm_session ) assert retval == ("model_id", "model_version", "icname") @@ -185,14 +218,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_inferenc @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_inference_component_endpoint_without_ic_name_no_ic_for_endpoint( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -200,7 +233,7 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ return_value=[] ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -210,14 +243,14 @@ def test_get_model_id_version_from_inference_component_endpoint_without_ic_name_ @patch( - "sagemaker.jumpstart.session_utils._get_model_id" - "_version_from_inference_component_endpoint_with_inference_component_name" + "sagemaker.jumpstart.session_utils._get_model" + "_info_from_inference_component_endpoint_with_inference_component_name" ) def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_multiple_ics_for_endpoint( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", ) @@ -227,7 +260,7 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) with pytest.raises(ValueError): - _get_model_id_version_from_inference_component_endpoint_without_inference_component_name( + _get_model_info_from_inference_component_endpoint_without_inference_component_name( "blahblah", mock_sm_session ) @@ -236,67 +269,119 @@ def test_get_model_id_version_from_ic_endpoint_without_inference_component_name_ ) -@patch("sagemaker.jumpstart.session_utils._get_model_id_version_from_model_based_endpoint") -def test_get_model_id_version_from_endpoint_non_inference_component_endpoint( - 
mock_get_model_id_version_from_model_based_endpoint, +@patch("sagemaker.jumpstart.session_utils._get_model_info_from_model_based_endpoint") +def test_get_model_info_from_endpoint_non_inference_component_endpoint( + mock_get_model_info_from_model_based_endpoint, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = False - mock_get_model_id_version_from_model_based_endpoint.return_value = ( + mock_get_model_info_from_model_based_endpoint.return_value = ( "model_id", "model_version", + None, + None, ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", None) - mock_get_model_id_version_from_model_based_endpoint.assert_called_once_with( + assert retval == ("model_id", "model_version", None, None, None) + mock_get_model_info_from_model_based_endpoint.assert_called_once_with( "blah", None, mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_called_once_with("blah") @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_" "component_endpoint_with_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_with_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.return_value = ( "model_id", "model_version", + None, + None, ) - retval = get_model_id_version_from_endpoint( + retval = get_model_info_from_endpoint( "blah", inference_component_name="icname", sagemaker_session=mock_sm_session ) - assert retval == ("model_id", "model_version", "icname") - mock_get_model_id_version_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( + assert retval == ("model_id", "model_version", "icname", None, None) + mock_get_model_info_from_inference_component_endpoint_with_inference_component_name.assert_called_once_with( "icname", mock_sm_session ) mock_sm_session.is_inference_component_based_endpoint.assert_not_called() @patch( - "sagemaker.jumpstart.session_utils._get_model_id_version_from_inference_component_" + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_info_from_endpoint_inference_component_endpoint_without_inference_component_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + "model_id", + "model_version", + None, + None, + "inferred-icname", + ) + + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", None, 
None) + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" "endpoint_without_inference_component_name" ) -def test_get_model_id_version_from_endpoint_inference_component_endpoint_without_inference_component_name( - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name, +def test_get_model_info_from_endpoint_inference_component_endpoint_with_inference_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, ): mock_sm_session = Mock() mock_sm_session.is_inference_component_based_endpoint.return_value = True - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.return_value = ( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( "model_id", "model_version", + "inference_config_name", + None, + "inferred-icname", + ) + + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) + + assert retval == ("model_id", "model_version", "inferred-icname", "inference_config_name", None) + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + + +@patch( + "sagemaker.jumpstart.session_utils._get_model_info_from_inference_component_" + "endpoint_without_inference_component_name" +) +def test_get_model_info_from_endpoint_inference_component_endpoint_with_training_config_name( + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name, +): + mock_sm_session = Mock() + mock_sm_session.is_inference_component_based_endpoint.return_value = True + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.return_value = ( + "model_id", + "model_version", + None, + "training_config_name", "inferred-icname", ) - retval = get_model_id_version_from_endpoint("blah", sagemaker_session=mock_sm_session) + retval = get_model_info_from_endpoint("blah", sagemaker_session=mock_sm_session) - assert retval == ("model_id", "model_version", "inferred-icname") - mock_get_model_id_version_from_inference_component_endpoint_without_inference_component_name.assert_called_once() + assert retval == ("model_id", "model_version", "inferred-icname", None, "training_config_name") + mock_get_model_info_from_inference_component_endpoint_without_inference_component_name.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 5ca01c3c52..03a85fee44 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -12,25 +12,35 @@ # language governing permissions and limitations under the License. 
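Editor's aside on the session_utils hunks above: `get_model_info_from_endpoint` (formerly `get_model_id_version_from_endpoint`) now returns a five-element tuple, appending the JumpStart inference and training config names to the old model-id/version/inference-component triple. A minimal runnable sketch of the new call shape, mirroring the mocks these tests use (the endpoint name and patched return values are illustrative, not real resources):

    from unittest.mock import Mock, patch

    from sagemaker.jumpstart.session_utils import get_model_info_from_endpoint

    mock_session = Mock()
    mock_session.is_inference_component_based_endpoint.return_value = False

    # Patch the tag lookup the same way the tests above do, so the sketch runs
    # offline; the 4-tuple is (model_id, model_version, inference_config_name,
    # training_config_name) as resolved from the endpoint's tags.
    with patch(
        "sagemaker.jumpstart.session_utils._get_model_info_from_model_based_endpoint",
        return_value=("model_id", "model_version", None, None),
    ):
        info = get_model_info_from_endpoint("my-endpoint", sagemaker_session=mock_session)

    # Callers that previously unpacked three values must now unpack five.
    assert info == ("model_id", "model_version", None, None, None)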
from __future__ import absolute_import import copy +from unittest import TestCase import pytest from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import ( JumpStartBenchmarkStat, JumpStartECRSpecs, + JumpStartEnvironmentVariable, JumpStartHyperparameter, JumpStartInstanceTypeVariants, JumpStartModelSpecs, JumpStartModelHeader, JumpStartConfigComponent, + DeploymentConfigMetadata, + JumpStartModelInitKwargs, + S3DataSource, ) +from sagemaker.utils import S3_PREFIX from tests.unit.sagemaker.jumpstart.constants import ( BASE_SPEC, + BASE_HOSTING_ADDITIONAL_DATA_SOURCES, INFERENCE_CONFIG_RANKINGS, INFERENCE_CONFIGS, TRAINING_CONFIG_RANKINGS, TRAINING_CONFIGS, + INIT_KWARGS, ) +from unittest.mock import Mock + INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants( { "regional_aliases": { @@ -109,7 +119,7 @@ "g4": { "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + "training_artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" }, }, "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, @@ -185,7 +195,7 @@ }, "p9": { "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": {"artifact_key": "do/re/mi"}, + "properties": {"training_artifact_key": "do/re/mi"}, }, "m2": { "regional_properties": {"image_uri": "$cpu_image_uri"}, @@ -264,13 +274,13 @@ "ml.p9.12xlarge": { "properties": { "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "artifact_key": "you/not/entertained", + "training_artifact_key": "you/not/entertained", } }, "g6": { "properties": { "environment_variables": {"BLAH": "4"}, - "artifact_key": "path/to/training/artifact.tar.gz", + "training_artifact_key": "path/to/training/artifact.tar.gz", "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", } }, @@ -321,14 +331,67 @@ def test_jumpstart_model_header(): assert header1 == header3 -def test_use_training_model_artifact(): - specs1 = JumpStartModelSpecs(BASE_SPEC) - assert specs1.use_training_model_artifact() - specs1.gated_bucket = True - assert not specs1.use_training_model_artifact() - specs1.gated_bucket = False - specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"} - assert not specs1.use_training_model_artifact() +class TestUseTrainingModelArtifact: + @pytest.fixture + def mock_specs(self): + specs = Mock(spec=JumpStartModelSpecs) + specs.training_instance_type_variants = Mock() + specs.supported_training_instance_types = ["ml.p3.2xlarge", "ml.g4dn.xlarge"] + specs.training_model_package_artifact_uris = {} + specs.training_artifact_key = None + return specs + + def test_use_training_model_artifact_with_env_var(self, mock_specs): + """Test when instance type variants have env var values.""" + mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.side_effect = [ + "some-value", + None, + ] + + result = JumpStartModelSpecs.use_training_model_artifact(mock_specs) + + assert result is False + mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.assert_any_call( + "ml.p3.2xlarge" + ) + + def test_use_training_model_artifact_with_package_uris(self, mock_specs): + """Test when model has training package artifact URIs.""" + mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = ( + None + ) + mock_specs.training_model_package_artifact_uris = { + "ml.p3.2xlarge": 
"arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/" + "llama2-13b-e155a2e0347b323fb882f1875851c5d3" + } + + result = JumpStartModelSpecs.use_training_model_artifact(mock_specs) + + assert result is False + + def test_use_training_model_artifact_with_artifact_key(self, mock_specs): + """Test when model has training artifact key.""" + mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = ( + None + ) + mock_specs.training_model_package_artifact_uris = {} + mock_specs.training_artifact_key = "some-key" + + result = JumpStartModelSpecs.use_training_model_artifact(mock_specs) + + assert result is True + + def test_use_training_model_artifact_without_artifact_key(self, mock_specs): + """Test when model has no training artifact key.""" + mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = ( + None + ) + mock_specs.training_model_package_artifact_uris = {} + mock_specs.training_artifact_key = None + + result = JumpStartModelSpecs.use_training_model_artifact(mock_specs) + + assert result is False def test_jumpstart_model_specs(): @@ -336,62 +399,87 @@ def test_jumpstart_model_specs(): specs1 = JumpStartModelSpecs(BASE_SPEC) assert specs1.model_id == "pytorch-ic-mobilenet-v2" - assert specs1.version == "1.0.0" - assert specs1.min_sdk_version == "2.49.0" + assert specs1.version == "3.0.6" + assert specs1.min_sdk_version == "2.189.0" assert specs1.training_supported assert specs1.incremental_training_supported assert specs1.hosting_ecr_specs == JumpStartECRSpecs( { "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "framework_version": "1.10.0", + "py_version": "py38", } ) assert specs1.training_ecr_specs == JumpStartECRSpecs( { "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "framework_version": "1.10.0", + "py_version": "py38", } ) - assert specs1.hosting_artifact_key == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" - assert specs1.training_artifact_key == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" + assert ( + specs1.hosting_artifact_key + == "pytorch-ic/pytorch-ic-mobilenet-v2/artifacts/inference/v2.0.0/" + ) + assert ( + specs1.training_artifact_key + == "pytorch-training/v2.0.0/train-pytorch-ic-mobilenet-v2.tar.gz" + ) assert ( specs1.hosting_script_key - == "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz" + == "source-directory-tarballs/pytorch/inference/ic/v2.0.0/sourcedir.tar.gz" ) assert ( specs1.training_script_key - == "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" + == "source-directory-tarballs/pytorch/transfer_learning/ic/v2.3.0/sourcedir.tar.gz" ) + assert specs1.default_training_dataset_key == "training-datasets/tf_flowers/" assert specs1.hyperparameters == [ + JumpStartHyperparameter( + { + "name": "train_only_top_layer", + "type": "text", + "options": ["True", "False"], + "default": "True", + "scope": "algorithm", + } + ), JumpStartHyperparameter( { "name": "epochs", "type": "int", - "default": 3, + "default": 5, + "scope": "algorithm", "min": 1, "max": 1000, - "scope": "algorithm", } ), JumpStartHyperparameter( { - "name": "adam-learning-rate", + "name": "learning_rate", "type": "float", - "default": 0.05, + "default": 0.001, + "scope": "algorithm", "min": 1e-08, "max": 1, - "scope": "algorithm", } ), JumpStartHyperparameter( { - "name": "batch-size", + "name": "batch_size", "type": "int", "default": 4, + "scope": 
"algorithm", "min": 1, "max": 1024, + } + ), + JumpStartHyperparameter( + { + "name": "reinitialize_top_layer", + "type": "text", + "options": ["Auto", "True", "False"], + "default": "Auto", "scope": "algorithm", } ), @@ -421,6 +509,7 @@ def test_jumpstart_model_specs(): ), ] + print(specs1.to_json()) assert specs1.to_json() == BASE_SPEC BASE_SPEC_COPY = copy.deepcopy(BASE_SPEC) @@ -432,6 +521,54 @@ def test_jumpstart_model_specs(): assert specs3 == specs1 +class TestS3DataSource(TestCase): + def setUp(self): + self.s3_data_source = S3DataSource( + { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/model/artifact/", + "model_access_config": {"accept_eula": False}, + } + ) + + def test_set_bucket_with_valid_s3_uri(self): + self.s3_data_source.set_bucket("my-bucket") + self.assertEqual(self.s3_data_source.s3_uri, f"{S3_PREFIX}my-bucket/key/to/model/artifact/") + + def test_set_bucket_with_existing_s3_uri(self): + self.s3_data_source.s3_uri = "s3://my-bucket/key/to/model/artifact/" + self.s3_data_source.set_bucket("random-new-bucket") + assert self.s3_data_source.s3_uri == "s3://random-new-bucket/key/to/model/artifact/" + + def test_set_bucket_with_existing_s3_uri_empty_bucket(self): + self.s3_data_source.s3_uri = "s3://my-bucket" + self.s3_data_source.set_bucket("random-new-bucket") + assert self.s3_data_source.s3_uri == "s3://random-new-bucket" + + def test_set_bucket_with_existing_s3_uri_empty(self): + self.s3_data_source.s3_uri = "s3://" + self.s3_data_source.set_bucket("random-new-bucket") + assert self.s3_data_source.s3_uri == "s3://random-new-bucket" + + +def test_get_speculative_decoding_s3_data_sources(): + specs = JumpStartModelSpecs({**BASE_SPEC, **BASE_HOSTING_ADDITIONAL_DATA_SOURCES}) + assert ( + specs.get_speculative_decoding_s3_data_sources() + == specs.hosting_additional_data_sources.speculative_decoding + ) + + +def test_get_additional_s3_data_sources(): + specs = JumpStartModelSpecs({**BASE_SPEC, **BASE_HOSTING_ADDITIONAL_DATA_SOURCES}) + data_sources = [ + *specs.hosting_additional_data_sources.speculative_decoding, + *specs.hosting_additional_data_sources.scripts, + ] + assert specs.get_additional_s3_data_sources() == data_sources + + def test_jumpstart_image_uri_instance_variants(): assert ( @@ -871,27 +1008,35 @@ def test_jumpstart_hosting_prepacked_artifact_key_instance_variants(): def test_jumpstart_training_artifact_key_instance_variants(): assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.g6.xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.g6.xlarge" + ) == "path/to/training/artifact.tar.gz" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.g4.9xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.g4.9xlarge" + ) == "path/to/prepacked/training/artifact/prefix/number2/" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.p9.9xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.p9.9xlarge" + ) == "do/re/mi" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key(instance_type="ml.p9.12xlarge") + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( + instance_type="ml.p9.12xlarge" + ) == "you/not/entertained" ) assert ( - INSTANCE_TYPE_VARIANT.get_instance_specific_artifact_key( + INSTANCE_TYPE_VARIANT.get_instance_specific_training_artifact_key( 
instance_type="ml.g9dsfsdfs.12xlarge" ) is None @@ -924,13 +1069,15 @@ def test_inference_configs_parsing(): "neuron-inference", "neuron-budget", "gpu-inference", + "gpu-inference-model-package", "gpu-inference-budget", + "gpu-accelerated", ] # Non-overrided fields in top config assert specs1.model_id == "pytorch-ic-mobilenet-v2" - assert specs1.version == "1.0.0" - assert specs1.min_sdk_version == "2.49.0" + assert specs1.version == "3.0.6" + assert specs1.min_sdk_version == "2.189.0" assert specs1.training_supported assert specs1.incremental_training_supported assert specs1.hosting_ecr_specs == JumpStartECRSpecs( @@ -943,51 +1090,72 @@ def test_inference_configs_parsing(): assert specs1.training_ecr_specs == JumpStartECRSpecs( { "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "framework_version": "1.10.0", + "py_version": "py38", } ) assert ( specs1.hosting_artifact_key == "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/" ) - assert specs1.training_artifact_key == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" + assert ( + specs1.training_artifact_key + == "pytorch-training/v2.0.0/train-pytorch-ic-mobilenet-v2.tar.gz" + ) assert ( specs1.hosting_script_key - == "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz" + == "source-directory-tarballs/pytorch/inference/ic/v2.0.0/sourcedir.tar.gz" ) assert ( specs1.training_script_key - == "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" + == "source-directory-tarballs/pytorch/transfer_learning/ic/v2.3.0/sourcedir.tar.gz" ) assert specs1.hyperparameters == [ + JumpStartHyperparameter( + { + "name": "train_only_top_layer", + "type": "text", + "options": ["True", "False"], + "default": "True", + "scope": "algorithm", + } + ), JumpStartHyperparameter( { "name": "epochs", "type": "int", - "default": 3, + "default": 5, + "scope": "algorithm", "min": 1, "max": 1000, - "scope": "algorithm", } ), JumpStartHyperparameter( { - "name": "adam-learning-rate", + "name": "learning_rate", "type": "float", - "default": 0.05, + "default": 0.001, + "scope": "algorithm", "min": 1e-08, "max": 1, - "scope": "algorithm", } ), JumpStartHyperparameter( { - "name": "batch-size", + "name": "batch_size", "type": "int", "default": 4, + "scope": "algorithm", "min": 1, "max": 1024, + } + ), + JumpStartHyperparameter( + { + "name": "reinitialize_top_layer", + "type": "text", + "options": ["Auto", "True", "False"], + "default": "Auto", "scope": "algorithm", } ), @@ -1016,6 +1184,80 @@ def test_inference_configs_parsing(): } ), ] + assert specs1.inference_environment_variables == [ + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + } + ), + JumpStartEnvironmentVariable( + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + 
"required_for_model_class": True, + } + ), + JumpStartEnvironmentVariable( + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + } + ), + JumpStartEnvironmentVariable( + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + } + ), + ] # Overrided fields in top config assert specs1.supported_inference_instance_types == ["ml.inf2.xlarge", "ml.inf2.2xlarge"] @@ -1024,7 +1266,9 @@ def test_inference_configs_parsing(): assert config.benchmark_metrics == { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), ] } assert len(config.config_components) == 1 @@ -1043,7 +1287,7 @@ def test_inference_configs_parsing(): "regional_aliases": { "us-west-2": { "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" } }, "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, @@ -1052,6 +1296,28 @@ def test_inference_configs_parsing(): ) assert list(config.config_components.keys()) == ["neuron-inference"] + config = specs1.inference_configs.configs["gpu-inference-model-package"] + assert config.config_components["gpu-inference-model-package"] == JumpStartConfigComponent( + "gpu-inference-model-package", + { + "default_inference_instance_type": "ml.p2.xlarge", + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/" + "llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + }, + }, + ) + assert config.resolved_config.get("inference_environment_variables") == [] + + spec = { + **BASE_SPEC, + **INFERENCE_CONFIGS, + **INFERENCE_CONFIG_RANKINGS, + "unrecognized-field": "blah", # New fields in base metadata fields should be ignored + } + specs1 = JumpStartModelSpecs(spec) + def test_set_inference_configs(): spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} @@ -1062,7 +1328,9 @@ def test_set_inference_configs(): "neuron-inference", "neuron-budget", "gpu-inference", + "gpu-inference-model-package", "gpu-inference-budget", + "gpu-accelerated", ] with pytest.raises(ValueError) as error: @@ -1070,7 +1338,7 @@ def test_set_inference_configs(): assert "Cannot find Jumpstart config name invalid_name." 
"List of config names that is supported by the model: " "['neuron-inference', 'neuron-inference-budget', " - "'gpu-inference-budget', 'gpu-inference']" in str(error.value) + "'gpu-inference-budget', 'gpu-inference', 'gpu-inference-model-package']" in str(error.value) assert specs1.supported_inference_instance_types == ["ml.inf2.xlarge", "ml.inf2.2xlarge"] specs1.set_config("gpu-inference") @@ -1091,62 +1359,86 @@ def test_training_configs_parsing(): # Non-overrided fields in top config # By default training config is not applied to model spec assert specs1.model_id == "pytorch-ic-mobilenet-v2" - assert specs1.version == "1.0.0" - assert specs1.min_sdk_version == "2.49.0" + assert specs1.version == "3.0.6" + assert specs1.min_sdk_version == "2.189.0" assert specs1.training_supported assert specs1.incremental_training_supported assert specs1.hosting_ecr_specs == JumpStartECRSpecs( { "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "framework_version": "1.10.0", + "py_version": "py38", } ) assert specs1.training_ecr_specs == JumpStartECRSpecs( { "framework": "pytorch", - "framework_version": "1.5.0", - "py_version": "py3", + "framework_version": "1.10.0", + "py_version": "py38", } ) - assert specs1.hosting_artifact_key == "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" - assert specs1.training_artifact_key == "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" + assert ( + specs1.hosting_artifact_key + == "pytorch-ic/pytorch-ic-mobilenet-v2/artifacts/inference/v2.0.0/" + ) + assert ( + specs1.training_artifact_key + == "pytorch-training/v2.0.0/train-pytorch-ic-mobilenet-v2.tar.gz" + ) assert ( specs1.hosting_script_key - == "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz" + == "source-directory-tarballs/pytorch/inference/ic/v2.0.0/sourcedir.tar.gz" ) assert ( specs1.training_script_key - == "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" + == "source-directory-tarballs/pytorch/transfer_learning/ic/v2.3.0/sourcedir.tar.gz" ) assert specs1.hyperparameters == [ + JumpStartHyperparameter( + { + "name": "train_only_top_layer", + "type": "text", + "options": ["True", "False"], + "default": "True", + "scope": "algorithm", + } + ), JumpStartHyperparameter( { "name": "epochs", "type": "int", - "default": 3, + "default": 5, + "scope": "algorithm", "min": 1, "max": 1000, - "scope": "algorithm", } ), JumpStartHyperparameter( { - "name": "adam-learning-rate", + "name": "learning_rate", "type": "float", - "default": 0.05, + "default": 0.001, + "scope": "algorithm", "min": 1e-08, "max": 1, - "scope": "algorithm", } ), JumpStartHyperparameter( { - "name": "batch-size", + "name": "batch_size", "type": "int", "default": 4, + "scope": "algorithm", "min": 1, "max": 1024, + } + ), + JumpStartHyperparameter( + { + "name": "reinitialize_top_layer", + "type": "text", + "options": ["Auto", "True", "False"], + "default": "Auto", "scope": "algorithm", } ), @@ -1180,18 +1472,29 @@ def test_training_configs_parsing(): assert config.benchmark_metrics == { "ml.tr1n1.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), ], "ml.tr1n1.4xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "50", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ), ], } assert len(config.config_components) == 1 assert 
config.config_components["neuron-training"] == JumpStartConfigComponent( "neuron-training", { + "default_training_instance_type": "ml.trn1.2xlarge", "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, "training_instance_type_variants": { "regional_aliases": { "us-west-2": { @@ -1206,6 +1509,83 @@ def test_training_configs_parsing(): assert list(config.config_components.keys()) == ["neuron-training"] +def test_additional_model_data_source_parsing(): + accelerated_first_rankings = { + "inference_config_rankings": { + "overall": { + "description": "Overall rankings of configs", + "rankings": [ + "gpu-accelerated", + "neuron-inference", + "neuron-inference-budget", + "gpu-inference", + "gpu-inference-budget", + ], + } + } + } + spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **accelerated_first_rankings} + specs1 = JumpStartModelSpecs(spec) + + config = specs1.inference_configs.get_top_config_from_ranking() + + assert config.benchmark_metrics == { + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + ] + } + assert len(config.config_components) == 1 + assert config.config_components["gpu-accelerated"] == JumpStartConfigComponent( + "gpu-accelerated", + { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + "hosting_additional_data_sources": { + "speculative_decoding": [ + { + "channel_name": "draft_model_name", + "artifact_version": "1.2.1", + "s3_data_source": { + "compression_type": "None", + "model_access_config": {"accept_eula": False}, + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/draft/model/artifact/", + }, + } + ], + }, + }, + ) + assert list(config.config_components.keys()) == ["gpu-accelerated"] + assert config.resolved_config["hosting_additional_data_sources"] == { + "speculative_decoding": [ + { + "channel_name": "draft_model_name", + "artifact_version": "1.2.1", + "s3_data_source": { + "compression_type": "None", + "model_access_config": {"accept_eula": False}, + "s3_data_type": "S3Prefix", + "s3_uri": "key/to/draft/model/artifact/", + }, + } + ], + } + + def test_set_inference_config(): spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} specs1 = JumpStartModelSpecs(spec) @@ -1225,11 +1605,9 @@ def test_set_training_config(): specs1 = JumpStartModelSpecs(spec) assert specs1.supported_training_instance_types == [ - "ml.p3.2xlarge", - "ml.p2.xlarge", - "ml.g4dn.2xlarge", "ml.m5.xlarge", "ml.c5.2xlarge", + "ml.m4.xlarge", ] specs1.set_config("gpu-training-budget", scope=JumpStartScriptScope.TRAINING) @@ -1248,3 +1626,38 @@ def test_set_training_config(): with pytest.raises(ValueError) as error: specs1.set_config("invalid_name", scope="unknown scope") + + +def test_deployment_config_metadata(): + spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + specs = JumpStartModelSpecs(spec) + jumpstart_config = 
specs.inference_configs.get_top_config_from_ranking() + + deployment_config_metadata = DeploymentConfigMetadata( + jumpstart_config.config_name, + jumpstart_config, + JumpStartModelInitKwargs( + model_id=specs.model_id, + model_data=INIT_KWARGS.get("model_data"), + image_uri=INIT_KWARGS.get("image_uri"), + instance_type=INIT_KWARGS.get("instance_type"), + env=INIT_KWARGS.get("env"), + config_name=jumpstart_config.config_name, + ), + ) + + json_obj = deployment_config_metadata.to_json() + + assert isinstance(json_obj, dict) + assert json_obj["DeploymentConfigName"] == jumpstart_config.config_name + for key in json_obj["BenchmarkMetrics"]: + assert len(json_obj["BenchmarkMetrics"][key]) == len( + jumpstart_config.benchmark_metrics.get(key) + ) + assert json_obj["AccelerationConfigs"] == jumpstart_config.resolved_config.get( + "acceleration_configs" + ) + assert json_obj["DeploymentArgs"]["ImageUri"] == INIT_KWARGS.get("image_uri") + assert json_obj["DeploymentArgs"]["ModelData"] == INIT_KWARGS.get("model_data") + assert json_obj["DeploymentArgs"]["Environment"] == INIT_KWARGS.get("env") + assert json_obj["DeploymentArgs"]["InstanceType"] == INIT_KWARGS.get("instance_type") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index c1ea8abcb8..de9be1d51d 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -13,17 +13,22 @@ from __future__ import absolute_import import os from unittest import TestCase -from mock.mock import Mock, patch +from unittest.mock import call, mock_open, Mock, patch +import json +from botocore.exceptions import ClientError import pytest import boto3 import random +from sagemaker_core.shapes import ModelAccessConfig from sagemaker import session from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( + _load_region_config, DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE, + ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE, EXTRA_MODEL_ID_TAGS, EXTRA_MODEL_VERSION_TAGS, JUMPSTART_DEFAULT_REGION_NAME, @@ -31,7 +36,9 @@ JUMPSTART_LOGGER, JUMPSTART_REGION_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME, + NEO_DEFAULT_REGION_NAME, JumpStartScriptScope, + JUMPSTART_LAUNCHED_REGIONS, ) from functools import partial from sagemaker.jumpstart.enums import JumpStartTag, MIMEType, JumpStartModelType @@ -43,12 +50,15 @@ JumpStartBenchmarkStat, JumpStartModelHeader, JumpStartVersionedModelId, + JumpStartLaunchedRegionInfo, ) from tests.unit.sagemaker.jumpstart.utils import ( get_base_spec_with_prototype_configs, get_spec_from_base_spec, get_special_model_spec, get_prototype_manifest, + get_base_deployment_configs_metadata, + get_base_deployment_configs, ) from mock import MagicMock @@ -60,79 +70,95 @@ def random_jumpstart_s3_uri(key): return f"s3://{random.choice(list(JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET))}/{key}" -def test_get_jumpstart_content_bucket(): - bad_region = "bad_region" - assert bad_region not in JUMPSTART_REGION_NAME_SET - with pytest.raises(ValueError): - utils.get_jumpstart_content_bucket(bad_region) - - -def test_get_jumpstart_content_bucket_no_args(): - assert ( - utils.get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) - == utils.get_jumpstart_content_bucket() - ) - +class TestBucketUtils(TestCase): + def test_get_jumpstart_content_bucket(self): + bad_region = "bad_region" + assert bad_region not in 
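For readers skimming the `test_deployment_config_metadata` additions above, the serialized payload being asserted has roughly this shape (key names come from the assertions; all values here are illustrative placeholders, not SDK defaults):

    # Approximate structure of DeploymentConfigMetadata.to_json()
    deployment_config_json = {
        "DeploymentConfigName": "neuron-inference",       # jumpstart_config.config_name
        "BenchmarkMetrics": {"ml.inf2.2xlarge": ["..."]}, # one list per instance type
        "AccelerationConfigs": None,                      # resolved_config.get("acceleration_configs")
        "DeploymentArgs": {
            "ImageUri": "<image-uri>",
            "ModelData": "<model-data>",
            "Environment": {"KEY": "value"},
            "InstanceType": "ml.inf2.2xlarge",
        },
    }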
JUMPSTART_REGION_NAME_SET + with pytest.raises(ValueError): + utils.get_jumpstart_content_bucket(bad_region) -def test_get_jumpstart_content_bucket_override(): - with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}): - with patch("logging.Logger.info") as mocked_info_log: - random_region = "random_region" - assert "some-val" == utils.get_jumpstart_content_bucket(random_region) - mocked_info_log.assert_called_with("Using JumpStart bucket override: 'some-val'") + def test_get_jumpstart_content_bucket_no_args(self): + assert ( + utils.get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) + == utils.get_jumpstart_content_bucket() + ) + def test_get_jumpstart_content_bucket_override(self): + with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_jumpstart_content_bucket(random_region) + mocked_info_log.assert_called_with("Using JumpStart bucket override: 'some-val'") -def test_get_jumpstart_gated_content_bucket(): - bad_region = "bad_region" - assert bad_region not in JUMPSTART_REGION_NAME_SET - with pytest.raises(ValueError): - utils.get_jumpstart_gated_content_bucket(bad_region) + def test_get_jumpstart_gated_content_bucket(self): + bad_region = "bad_region" + assert bad_region not in JUMPSTART_REGION_NAME_SET + with pytest.raises(ValueError): + utils.get_jumpstart_gated_content_bucket(bad_region) + def test_get_jumpstart_gated_content_bucket_no_args(self): + assert ( + utils.get_jumpstart_gated_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) + == utils.get_jumpstart_gated_content_bucket() + ) -def test_get_jumpstart_gated_content_bucket_no_args(): - assert ( - utils.get_jumpstart_gated_content_bucket(JUMPSTART_DEFAULT_REGION_NAME) - == utils.get_jumpstart_gated_content_bucket() - ) + def test_get_jumpstart_gated_content_bucket_override(self): + with patch.dict( + os.environ, {ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE: "some-val"} + ): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_jumpstart_gated_content_bucket(random_region) + mocked_info_log.assert_called_once_with( + "Using JumpStart gated bucket override: 'some-val'" + ) + def test_get_jumpstart_launched_regions_message(self): -def test_get_jumpstart_gated_content_bucket_override(): - with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE: "some-val"}): - with patch("logging.Logger.info") as mocked_info_log: - random_region = "random_region" - assert "some-val" == utils.get_jumpstart_gated_content_bucket(random_region) - mocked_info_log.assert_called_once_with( - "Using JumpStart gated bucket override: 'some-val'" + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is not available in any region." ) + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region"}): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is available in some_region region." + ) -def test_get_jumpstart_launched_regions_message(): + with patch( + "sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", + {"some_region1", "some_region2"}, + ): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is available in some_region1 and some_region2 regions." 
+ ) - with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}): - assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is not available in any region." - ) + with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"a", "b", "c"}): + assert ( + utils.get_jumpstart_launched_regions_message() + == "JumpStart is available in a, b, and c regions." + ) - with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region"}): - assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is available in some_region region." - ) + def test_get_neo_content_bucket(self): + bad_region = "bad_region" + assert bad_region not in JUMPSTART_REGION_NAME_SET + with pytest.raises(ValueError): + utils.get_neo_content_bucket(bad_region) - with patch( - "sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"some_region1", "some_region2"} - ): + def test_get_neo_content_bucket_no_args(self): assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is available in some_region1 and some_region2 regions." + utils.get_neo_content_bucket(NEO_DEFAULT_REGION_NAME) == utils.get_neo_content_bucket() ) - with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {"a", "b", "c"}): - assert ( - utils.get_jumpstart_launched_regions_message() - == "JumpStart is available in a, b, and c regions." - ) + def test_get_neo_content_bucket_override(self): + with patch.dict(os.environ, {ENV_VARIABLE_NEO_CONTENT_BUCKET_OVERRIDE: "some-val"}): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_neo_content_bucket(random_region) + mocked_info_log.assert_called_with("Using Neo bucket override: 'some-val'") def test_get_formatted_manifest(): @@ -207,16 +233,16 @@ def test_is_jumpstart_model_uri(): assert utils.is_jumpstart_model_uri(random_jumpstart_s3_uri("random_key")) -def test_add_jumpstart_model_id_version_tags(): +def test_add_jumpstart_model_info_tags(): tags = None model_id = "model_id" version = "version" + inference_config_name = "inference_config_name" + training_config_name = "training_config_name" assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, @@ -228,9 +254,7 @@ def test_add_jumpstart_model_id_version_tags(): assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version_2"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "random key", "Value": "random_value"}, @@ -241,9 +265,7 @@ def test_add_jumpstart_model_id_version_tags(): {"Key": "random key", "Value": "random_value"}, {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ 
{"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, @@ -254,9 +276,7 @@ def test_add_jumpstart_model_id_version_tags(): assert [ {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id_2"}, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version - ) + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) tags = [ {"Key": "random key", "Value": "random_value"}, @@ -265,8 +285,58 @@ def test_add_jumpstart_model_id_version_tags(): version = None assert [ {"Key": "random key", "Value": "random_value"}, - ] == utils.add_jumpstart_model_id_version_tags( - tags=tags, model_id=model_id, model_version=version + ] == utils.add_jumpstart_model_info_tags(tags=tags, model_id=model_id, model_version=version) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + {"Key": "sagemaker-sdk:jumpstart-inference-config-name", "Value": "inference_config_name"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=inference_config_name, + scope=JumpStartScriptScope.INFERENCE, + ) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + {"Key": "sagemaker-sdk:jumpstart-training-config-name", "Value": "training_config_name"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=training_config_name, + scope=JumpStartScriptScope.TRAINING, + ) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + model_id = "model_id" + version = "version" + assert [ + {"Key": "random key", "Value": "random_value"}, + {"Key": "sagemaker-sdk:jumpstart-model-id", "Value": "model_id"}, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "version"}, + ] == utils.add_jumpstart_model_info_tags( + tags=tags, + model_id=model_id, + model_version=version, + config_name=training_config_name, ) @@ -1296,7 +1366,7 @@ def test_validate_model_id_and_get_type_invalid( ) assert ( - utils.validate_model_id_and_get_type("pytorch-eqa-bert-base-cased") + utils.validate_model_id_and_get_type("pytorch-ic-mobilenet-v2") == JumpStartModelType.OPEN_WEIGHTS ) mock_get_manifest.assert_called_with( @@ -1318,11 +1388,9 @@ def test_no_model_id_no_version_found(self): mock_sagemaker_session.list_tags = mock_list_tags mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] - self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None), + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1335,11 +1403,9 @@ def test_model_id_no_version_found(self): {"Key": JumpStartTag.MODEL_ID, "Value": "model_id"}, ] - self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", 
mock_sagemaker_session - ), - ("model_id", None), + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id", None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1352,11 +1418,67 @@ def test_no_model_id_version_found(self): {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version"}, ] - self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, "model_version"), + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, "model_version", None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_no_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [{"Key": "blah", "Value": "blah1"}] + + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_inference_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name"}, + ] + + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, "config_name", None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_training_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "config_name"}, + ] + + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, "config_name"), + ) + mock_list_tags.assert_called_once_with("some-arn") + + def test_both_config_name_found(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "inference_config_name"}, + {"Key": JumpStartTag.TRAINING_CONFIG_NAME, "Value": "training_config_name"}, + ] + + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, "inference_config_name", "training_config_name"), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1370,11 +1492,9 @@ def test_model_id_version_found(self): {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version"}, ] - self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id", "model_version"), + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id", "model_version", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1390,11 +1510,9 @@ def test_multiple_model_id_versions_found(self): {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_2"}, ] - self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None), + 
self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1410,11 +1528,9 @@ def test_multiple_model_id_versions_found_aliases_consistent(self): {"Key": random.choice(EXTRA_MODEL_VERSION_TAGS), "Value": "model_version_1"}, ] - self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - ("model_id_1", "model_version_1"), + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id_1", "model_version_1", None, None), ) mock_list_tags.assert_called_once_with("some-arn") @@ -1430,14 +1546,133 @@ def test_multiple_model_id_versions_found_aliases_inconsistent(self): {"Key": random.choice(EXTRA_MODEL_VERSION_TAGS), "Value": "model_version_2"}, ] - self.assertEquals( - utils.get_jumpstart_model_id_version_from_resource_arn( - "some-arn", mock_sagemaker_session - ), - (None, None), + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + (None, None, None, None), ) mock_list_tags.assert_called_once_with("some-arn") + def test_multiple_config_names_found_aliases_inconsistent(self): + mock_list_tags = Mock() + mock_sagemaker_session = Mock() + mock_sagemaker_session.list_tags = mock_list_tags + mock_list_tags.return_value = [ + {"Key": "blah", "Value": "blah1"}, + {"Key": JumpStartTag.MODEL_ID, "Value": "model_id_1"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "model_version_1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_1"}, + {"Key": JumpStartTag.INFERENCE_CONFIG_NAME, "Value": "config_name_2"}, + ] + + self.assertEqual( + utils.get_jumpstart_model_info_from_resource_arn("some-arn", mock_sagemaker_session), + ("model_id_1", "model_version_1", None, None), + ) + mock_list_tags.assert_called_once_with("some-arn") + + +class TestJumpStartLaunchedRegions(TestCase): + def test_regions_not_empty(self): + self.assertTrue(len(JUMPSTART_LAUNCHED_REGIONS) > 0) + + +class TestLoadRegionConfig(TestCase): + def setUp(self): + # Sample valid config that matches the expected structure + self.valid_config = { + "us-east-1": { + "content_bucket": "jumpstart-cache-prod-us-east-1", + "gated_content_bucket": "jumpstart-private-cache-prod-us-east-1", + "neo_content_bucket": "jumpstart-neo-cache-prod-us-east-1", + }, + "us-west-2": { + "content_bucket": "jumpstart-cache-prod-us-west-2", + }, + } + self.config_json = json.dumps(self.valid_config) + + @patch("builtins.open", new_callable=mock_open) + def test_successful_config_load(self, mock_file): + # Set up mock to return valid config + mock_file.return_value.__enter__().read.return_value = self.config_json + + result = _load_region_config("dummy/path") + + # Verify the returned set contains JumpStartLaunchedRegionInfo objects + self.assertTrue(all(isinstance(region, JumpStartLaunchedRegionInfo) for region in result)) + + for region in result: + if region.region_name == "us-east-1": + self.assertEqual(region.region_name, "us-east-1") + self.assertEqual(region.content_bucket, "jumpstart-cache-prod-us-east-1") + self.assertEqual( + region.gated_content_bucket, "jumpstart-private-cache-prod-us-east-1" + ) + self.assertEqual(region.neo_content_bucket, "jumpstart-neo-cache-prod-us-east-1") + + elif region.region_name == "us-west-2": + self.assertEqual(region.region_name, "us-west-2") + self.assertEqual(region.content_bucket, 
"jumpstart-cache-prod-us-west-2") + self.assertIsNone(region.gated_content_bucket) + self.assertIsNone(region.neo_content_bucket) + else: + raise AssertionError(f"Unexpected region name found: {region.region_name}") + + @patch("builtins.open", new_callable=mock_open) + def test_missing_required_field(self, mock_file): + # Config missing required content_bucket field + invalid_config = { + "us-east-1": { + "gated_content_bucket": "XXXXXXXXXXX", + "neo_content_bucket": "some-other-bucket", + } + } + mock_file.return_value.__enter__().read.return_value = json.dumps(invalid_config) + + # Should return empty dict due to exception handling + result = _load_region_config("dummy/path") + self.assertEqual(result, set()) + + @patch("builtins.open") + def test_file_not_found(self, mock_file): + # Simulate file not found + mock_file.side_effect = FileNotFoundError() + + # Should return empty dict due to exception handling + result = _load_region_config("dummy/path") + self.assertEqual(result, set()) + + @patch("builtins.open", new_callable=mock_open) + def test_invalid_json(self, mock_file): + # Setup mock to return invalid JSON + mock_file.return_value.__enter__().read.return_value = "invalid json content" + + # Should return empty dict due to exception handling + result = _load_region_config("dummy/path") + self.assertEqual(result, set()) + + @patch("builtins.open", new_callable=mock_open) + def test_empty_config(self, mock_file): + # Setup mock to return empty JSON object + mock_file.return_value.__enter__().read.return_value = "{}" + + result = _load_region_config("dummy/path") + self.assertEqual(result, set()) + + @patch("sagemaker.jumpstart.constants.JUMPSTART_LOGGER") + @patch("builtins.open") + def test_logging_on_error(self, mock_file, mock_logger): + + # Simulate an error + mock_file.side_effect = Exception("Test error") + + result = _load_region_config("dummy/path") + + self.assertEqual(result, set()) + + # Verify error was logged + mock_logger.error.assert_called_once() + class TestJumpStartLogger(TestCase): @patch.dict("os.environ", {}) @@ -1529,6 +1764,8 @@ def test_get_jumpstart_config_names_success( "neuron-inference-budget", "gpu-inference-budget", "gpu-inference", + "gpu-inference-model-package", + "gpu-accelerated", ] @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1599,22 +1836,44 @@ def test_get_jumpstart_benchmark_stats_full_list( ) == { "neuron-inference": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, "neuron-inference-budget": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, "gpu-inference-budget": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, "gpu-inference": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ] + }, + "gpu-inference-model-package": { + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ] + }, + 
"gpu-accelerated": { + "ml.p3.2xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, } @@ -1634,12 +1893,16 @@ def test_get_jumpstart_benchmark_stats_partial_list( ) == { "neuron-inference-budget": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, "gpu-inference-budget": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] }, } @@ -1659,7 +1922,9 @@ def test_get_jumpstart_benchmark_stats_single_stat( ) == { "neuron-inference-budget": { "ml.inf2.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ] } } @@ -1687,6 +1952,16 @@ def test_get_jumpstart_benchmark_stats_training( ): patched_get_model_specs.side_effect = get_base_spec_with_prototype_configs + print( + utils.get_benchmark_stats( + "mock-region", + "mock-model", + "mock-model-version", + scope=JumpStartScriptScope.TRAINING, + config_names=["neuron-training", "gpu-training-budget"], + ) + ) + assert utils.get_benchmark_stats( "mock-region", "mock-model", @@ -1696,15 +1971,569 @@ def test_get_jumpstart_benchmark_stats_training( ) == { "neuron-training": { "ml.tr1n1.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) ], "ml.tr1n1.4xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "50", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "50", "unit": "Tokens/S", "concurrency": 1} + ) ], }, "gpu-training-budget": { "ml.p3.2xlarge": [ - JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": "1"} + ) ] }, } + + +class TestUserAgent: + @patch("sagemaker.jumpstart.utils.os.getenv") + def test_get_jumpstart_user_agent_extra_suffix(self, mock_getenv): + mock_getenv.return_value = False + assert utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", None, "False" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version") + mock_getenv.return_value = None + assert utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", None, "False" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version") + mock_getenv.return_value = "True" + assert not utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", None, "True" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#True") + mock_getenv.return_value = True + assert not utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", None, "True" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#True") + mock_getenv.return_value = False + assert utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "some-config", "False" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_config#some-config") + + @patch("sagemaker.jumpstart.utils.botocore.session") + 
@patch("sagemaker.jumpstart.utils.botocore.config.Config") + @patch("sagemaker.jumpstart.utils.get_jumpstart_user_agent_extra_suffix") + @patch("sagemaker.jumpstart.utils.boto3.Session") + @patch("sagemaker.jumpstart.utils.boto3.client") + @patch("sagemaker.jumpstart.utils.Session") + def test_get_default_jumpstart_session_with_user_agent_suffix( + self, + mock_sm_session, + mock_boto3_client, + mock_botocore_session, + mock_get_jumpstart_user_agent_extra_suffix, + mock_botocore_config, + mock_boto3_session, + ): + utils.get_default_jumpstart_session_with_user_agent_suffix("model_id", "model_version") + mock_boto3_session.get_session.assert_called_once_with() + mock_get_jumpstart_user_agent_extra_suffix.assert_called_once_with( + model_id="model_id", + model_version="model_version", + config_name=None, + is_hub_content=False, + ) + mock_botocore_config.assert_called_once_with( + user_agent_extra=mock_get_jumpstart_user_agent_extra_suffix.return_value + ) + mock_botocore_session.assert_called_once_with( + region_name=JUMPSTART_DEFAULT_REGION_NAME, + botocore_session=mock_boto3_session.get_session.return_value, + ) + mock_boto3_client.assert_has_calls( + [ + call( + "sagemaker", + region_name=JUMPSTART_DEFAULT_REGION_NAME, + config=mock_botocore_config.return_value, + ), + call( + "sagemaker-runtime", + region_name=JUMPSTART_DEFAULT_REGION_NAME, + config=mock_botocore_config.return_value, + ), + ], + any_order=True, + ) + + @patch("botocore.client.BaseClient._make_request") + def test_get_default_jumpstart_session_with_user_agent_suffix_http_header( + self, + mock_make_request, + ): + session = utils.get_default_jumpstart_session_with_user_agent_suffix( + "model_id", "model_version" + ) + try: + session.sagemaker_client.list_endpoints() + except Exception: + pass + + assert ( + "md/js_model_id#model_id md/js_model_ver#model_version" + in mock_make_request.call_args[0][1]["headers"]["User-Agent"] + ) + + +def test_extract_metrics_from_deployment_configs(): + configs = get_base_deployment_configs_metadata() + configs[0].benchmark_metrics = None + configs[2].deployment_args = None + + data = utils.get_metrics_from_deployment_configs(configs) + + for key in data: + assert len(data[key]) == (len(configs) - 2) + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + } + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ], + "ml.gd4.xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ], + }, + ) + + assert err is None + for key in out: + assert len(out[key]) == 2 + for metric in out[key]: + if metric.name == "Instance Rate": + assert metric.to_json() == { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + "concurrency": None, + } + + +def test__normalize_benchmark_metrics(): + rate, metrics = utils._normalize_benchmark_metrics( + [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", 
"unit": "Tokens/S", "concurrency": 2} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + JumpStartBenchmarkStat( + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76", "concurrency": None} + ), + ] + ) + + assert rate == JumpStartBenchmarkStat( + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76", "concurrency": None} + ) + assert metrics == { + 1: [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ), + ], + 2: [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + JumpStartBenchmarkStat( + {"name": "Throughput", "value": "100", "unit": "Tokens/S", "concurrency": 2} + ), + ], + } + + +@pytest.mark.parametrize( + "name, unit, expected", + [ + ("latency", "sec", "Latency, TTFT (P50 in sec)"), + ("throughput", "tokens/sec", "Throughput (P50 in tokens/sec/user)"), + ], +) +def test_normalize_benchmark_metric_column_name(name, unit, expected): + out = utils._normalize_benchmark_metric_column_name(name, unit) + + assert out == expected + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics_client_ex( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = ClientError( + { + "Error": { + "Message": "is not authorized to perform: pricing:GetProducts", + "Code": "AccessDenied", + }, + }, + "GetProducts", + ) + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": 1} + ) + ], + }, + ) + + assert err["Message"] == "is not authorized to perform: pricing:GetProducts" + assert err["Code"] == "AccessDenied" + for key in out: + assert len(out[key]) == 1 + + +@pytest.mark.parametrize( + "stats, expected", + [ + (None, True), + ( + [ + JumpStartBenchmarkStat( + { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + "concurrency": None, + } + ) + ], + True, + ), + ( + [ + JumpStartBenchmarkStat( + {"name": "Latency", "value": "100", "unit": "Tokens/S", "concurrency": None} + ) + ], + False, + ), + ], +) +def test_has_instance_rate_stat(stats, expected): + assert utils.has_instance_rate_stat(stats) is expected + + +def test_get_latest_version(): + assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0"]) == "2.16.0" + + +def test_get_latest_version_empty_list_is_none(): + assert utils.get_latest_version([]) is None + + +def test_get_latest_version_none_is_none(): + assert utils.get_latest_version(None) is None + + +def test_get_latest_version_with_invalid_sem_ver(): + assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0", "abc"]) == "abc" + + +@pytest.mark.parametrize( + "data, expected", + [(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())], +) +def test_deployment_config_response_data(data, expected): + out = utils.deployment_config_response_data(data) + assert out == expected + + +class TestGetEulaMessage(TestCase): + mock_model_specs = Mock(model_id="some-model-id", hosting_eula_key="some-eula-key") + + def test_get_domain_for_region(self): + self.assertEqual( + utils.get_eula_message(self.mock_model_specs, "us-west-2"), + "Model 'some-model-id' requires accepting end-user license agreement 
(EULA). See" + " https://jumpstart-cache-prod-us-west-2.s3.us-west-2.amazonaws.com/some-eula-key " + "for terms of use.", + ) + self.assertEqual( + utils.get_eula_message(self.mock_model_specs, "cn-north-1"), + "Model 'some-model-id' requires accepting end-user license agreement (EULA). See" + " https://jumpstart-cache-prod-cn-north-1.s3.cn-north-1.amazonaws.com.cn/some-eula-key " + "for terms of use.", + ) + + +class TestAcceptEulaModelAccessConfig(TestCase): + MOCK_PUBLIC_MODEL_ID = "mock_public_model_id" + MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/public/resources/", + }, + "HostingEulaKey": None, + } + ] + MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/public/resources/", + }, + } + ] + MOCK_GATED_MODEL_ID = "mock_gated_model_id" + MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/", + }, + "HostingEulaKey": "fmhMetadata/eula/llama3_2Eula.txt", + } + ] + MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/", + "ModelAccessConfig": {"AcceptEula": True}, + }, + } + ] + + # Public Positive Cases + + def test_public_additional_model_data_source_should_pass_through(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs=None, + model_id=self.MOCK_PUBLIC_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert ( + additional_model_data_sources + == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + def test_multiple_public_additional_model_data_source_should_pass_through_both(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + ), + model_access_configs=None, + model_id=self.MOCK_PUBLIC_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert additional_model_data_sources == ( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + def test_public_additional_model_data_source_with_model_access_config_should_ignore_it(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)}, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert ( + additional_model_data_sources + == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + def 
test_no_additional_model_data_source_should_pass_through(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=None, + model_access_configs=None, + model_id=self.MOCK_PUBLIC_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert not additional_model_data_sources + + # Gated Positive Cases + + def test_gated_additional_model_data_source_should_accept_it(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)}, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert ( + additional_model_data_sources + == self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + def test_multiple_gated_additional_model_data_source_should_accept_both(self): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=( + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + ), + model_access_configs={ + self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True), + self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True), + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert additional_model_data_sources == ( + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + def test_gated_additional_model_data_source_already_accepted_with_no_hosting_eula_key_should_pass_through( + self, + ): + mock_gated_deploy_config_additional_model_data_pre_accepted = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/", + "ModelAccessConfig": {"AcceptEula": True}, + }, + } + ] + + utils._add_model_access_configs_to_model_data_sources( + model_data_sources=mock_gated_deploy_config_additional_model_data_pre_accepted, + model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=False)}, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # Mixed Positive Cases + + def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other( + self, + ): + # WHERE / WHEN + additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources( + model_data_sources=( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + ), + model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)}, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + # THEN + assert additional_model_data_sources == ( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL + ) + + # Test Gated Negative Tests + + def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error( + self, + ): + # WHERE / WHEN / THEN + with self.assertRaises(ValueError): + utils._add_model_access_configs_to_model_data_sources( + 
model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs=None, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + def test_multiple_mixed_additional_no_model_data_source_should_raise_value_error(self): + # WHERE / WHEN / THEN + with self.assertRaises(ValueError): + utils._add_model_access_configs_to_model_data_sources( + model_data_sources=( + self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + + self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL + ), + model_access_configs=None, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error( + self, + ): + # WHERE / WHEN / THEN + with self.assertRaises(ValueError): + utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={ + self.MOCK_PUBLIC_MODEL_ID: ModelAccessConfig(accept_eula=True) + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) + + def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error( + self, + ): + # WHERE / WHEN / THEN + with self.assertRaises(ValueError): + utils._add_model_access_configs_to_model_data_sources( + model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL, + model_access_configs={ + self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=False) + }, + model_id=self.MOCK_GATED_MODEL_ID, + region=JUMPSTART_DEFAULT_REGION_NAME, + ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e102251060..bd870dc461 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -12,9 +12,10 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import copy -from typing import List +from typing import List, Dict, Any, Optional import boto3 +from sagemaker.compute_resource_requirements import ResourceRequirements from sagemaker.jumpstart.cache import JumpStartModelsCache from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, @@ -22,11 +23,16 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, + JumpStartModelInitKwargs, + DeploymentConfigMetadata, + JumpStartModelDeployKwargs, + JumpStartBenchmarkStat, + JumpStartAdditionalDataSources, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest @@ -43,6 +49,8 @@ SPECIAL_MODEL_SPECS_DICT, TRAINING_CONFIG_RANKINGS, TRAINING_CONFIGS, + DEPLOYMENT_CONFIGS, + INIT_KWARGS, ) @@ -108,7 +116,9 @@ def get_prototype_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. 
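The TestAcceptEulaModelAccessConfig cases above pin down a small contract: public additional data sources (HostingEulaKey absent or None) pass through with the HostingEulaKey field stripped, gated sources (HostingEulaKey set) must be matched by a ModelAccessConfig with accept_eula=True for the same model ID and receive an injected {"AcceptEula": True} on their S3DataSource, and everything else raises ValueError. A minimal sketch of that behavior follows; the helper name and locals are illustrative, the access-config objects are duck-typed, and this is not the SDK's actual implementation in sagemaker.jumpstart.utils.

from typing import Any, Dict, List, Optional


def add_model_access_configs_sketch(
    model_data_sources: Optional[List[Dict[str, Any]]],
    model_access_configs: Optional[Dict[str, Any]],
    model_id: str,
) -> Optional[List[Dict[str, Any]]]:
    """Hypothetical re-statement of the EULA contract the tests above encode."""
    if not model_data_sources:
        return model_data_sources  # nothing to do for models without extra sources
    acked_sources = []
    for source in model_data_sources:
        # The POST_CALL fixtures never carry HostingEulaKey, so strip it here.
        copied = {key: value for key, value in source.items() if key != "HostingEulaKey"}
        if source.get("HostingEulaKey"):
            # Gated source: demand an explicit opt-in for this exact model_id.
            access_config = (model_access_configs or {}).get(model_id)
            if access_config is None or not access_config.accept_eula:
                raise ValueError(
                    f"Model '{model_id}' requires accepting an end-user license "
                    "agreement; pass a ModelAccessConfig with accept_eula=True."
                )
            copied["S3DataSource"] = dict(
                copied["S3DataSource"], ModelAccessConfig={"AcceptEula": True}
            )
        acked_sources.append(copied)
    return acked_sources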
@@ -124,7 +134,9 @@ def get_special_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -140,7 +152,9 @@ def get_special_model_spec_for_inference_component_based_endpoint( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID and adding @@ -163,8 +177,10 @@ def get_spec_from_base_spec( model_id: str = None, version_str: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: if version and version_str: @@ -187,6 +203,7 @@ def get_spec_from_base_spec( "catboost" not in model_id, "lightgbm" not in model_id, "sklearn" not in model_id, + "ai21" not in model_id, ] ): raise KeyError("Bad model ID") @@ -209,8 +226,10 @@ def get_base_spec_with_prototype_configs( region: str = None, model_id: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: spec = copy.deepcopy(BASE_SPEC) inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} @@ -222,35 +241,72 @@ def get_base_spec_with_prototype_configs( return JumpStartModelSpecs(spec) +def get_base_spec_with_prototype_configs_with_missing_benchmarks( + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, +) -> JumpStartModelSpecs: + spec = copy.deepcopy(BASE_SPEC) + copy_inference_configs = copy.deepcopy(INFERENCE_CONFIGS) + copy_inference_configs["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None + + inference_configs = {**copy_inference_configs, **INFERENCE_CONFIG_RANKINGS} + training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS} + + spec.update(inference_configs) + spec.update(training_configs) + + return JumpStartModelSpecs(spec) + + +def get_prototype_spec_with_configs( + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: str = None, + sagemaker_session: boto3.Session = None, +) -> JumpStartModelSpecs: + spec = copy.deepcopy(PROTOTYPICAL_MODEL_SPECS_DICT[model_id]) + inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS} + + spec.update(inference_configs) + spec.update(training_configs) + + return JumpStartModelSpecs(spec) + + def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, - key: JumpStartCachedS3ContentKey, - value: JumpStartCachedS3ContentValue, -) -> JumpStartCachedS3ContentValue: + key: JumpStartCachedContentKey, + value: JumpStartCachedContentValue, +) -> 
JumpStartCachedContentValue: - filetype, s3_key = key.file_type, key.s3_key - if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: + data_type, id_info = key.data_type, key.id_info + if data_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: - return JumpStartCachedS3ContentValue( - formatted_content=get_formatted_manifest(BASE_MANIFEST) - ) + return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST)) - if filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: - _, model_id, specs_version = s3_key.split("/") + if data_type == JumpStartS3FileType.OPEN_WEIGHT_SPECS: + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) - if filetype == JumpStartS3FileType.PROPRIETARY_MANIFEST: - return JumpStartCachedS3ContentValue( + if data_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return JumpStartCachedContentValue( formatted_content=get_formatted_manifest(BASE_PROPRIETARY_MANIFEST) ) - if filetype == JumpStartS3FileType.PROPRIETARY_SPECS: - _, model_id, specs_version = s3_key.split("/") + if data_type == JumpStartS3FileType.PROPRIETARY_SPECS: + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("proprietary_specs_", "").replace(".json", "") - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec( model_id=model_id, version=version, @@ -258,7 +314,7 @@ def patched_retrieval_function( ) ) - raise ValueError(f"Bad value for filetype: {filetype}") + raise ValueError(f"Bad value for filetype: {data_type}") def overwrite_dictionary( @@ -280,3 +336,128 @@ def overwrite_dictionary( base_dictionary[key] = value return base_dictionary + + +def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, Any]]: + configs = copy.deepcopy(DEPLOYMENT_CONFIGS) + configs[0]["AccelerationConfigs"] = [ + {"Type": "Speculative-Decoding", "Enabled": True, "Spec": {"Version": "0.1"}} + ] + return configs + + +def get_mock_init_kwargs( + model_id: str, config_name: Optional[str] = None +) -> JumpStartModelInitKwargs: + kwargs = JumpStartModelInitKwargs( + model_id=model_id, + model_type=JumpStartModelType.OPEN_WEIGHTS, + model_data=INIT_KWARGS.get("model_data"), + image_uri=INIT_KWARGS.get("image_uri"), + instance_type=INIT_KWARGS.get("instance_type"), + env=INIT_KWARGS.get("env"), + resources=ResourceRequirements(), + config_name=config_name, + ) + setattr(kwargs, "model_reference_arn", None) + setattr(kwargs, "hub_content_type", None) + return kwargs + + +def get_base_deployment_configs_metadata( + omit_benchmark_metrics: bool = False, +) -> List[DeploymentConfigMetadata]: + specs = ( + get_base_spec_with_prototype_configs_with_missing_benchmarks() + if omit_benchmark_metrics + else get_base_spec_with_prototype_configs() + ) + configs = [] + for config_name in specs.inference_configs.config_rankings.get("overall").rankings: + jumpstart_config = specs.inference_configs.configs.get(config_name) + benchmark_metrics = jumpstart_config.benchmark_metrics + + if benchmark_metrics: + for instance_type in benchmark_metrics: + benchmark_metrics[instance_type].append( + JumpStartBenchmarkStat( + { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + "concurrency": None, + } + ) + ) + + configs.append( + DeploymentConfigMetadata( + config_name=config_name, 
+ metadata_config=jumpstart_config, + init_kwargs=get_mock_init_kwargs( + get_base_spec_with_prototype_configs().model_id, config_name + ), + deploy_kwargs=JumpStartModelDeployKwargs( + model_id=get_base_spec_with_prototype_configs().model_id, + ), + ) + ) + return configs + + +def get_base_deployment_configs( + omit_benchmark_metrics: bool = False, +) -> List[Dict[str, Any]]: + configs = [] + for config in get_base_deployment_configs_metadata(omit_benchmark_metrics): + config_json = config.to_json() + if config_json["BenchmarkMetrics"]: + config_json["BenchmarkMetrics"] = { + config.deployment_args.instance_type: config_json["BenchmarkMetrics"].get( + config.deployment_args.instance_type + ) + } + configs.append(config_json) + return configs + + +def append_instance_stat_metrics( + metrics: Dict[str, List[JumpStartBenchmarkStat]] +) -> Dict[str, List[JumpStartBenchmarkStat]]: + if metrics is not None: + for key in metrics: + metrics[key].append( + JumpStartBenchmarkStat( + { + "name": "Instance Rate", + "value": "3.76", + "unit": "USD/Hrs", + "concurrency": None, + } + ) + ) + return metrics + + +def append_gated_draft_model_specs_to_jumpstart_model_spec(*args, **kwargs): + augmented_spec = get_prototype_model_spec(*args, **kwargs) + + gated_s3_uri = "meta-textgeneration/meta-textgeneration-llama-3-2-1b-instruct/artifacts/inference-prepack/v1.0.0/" + augmented_spec.hosting_additional_data_sources = JumpStartAdditionalDataSources( + spec={ + "speculative_decoding": [ + { + "channel_name": "draft_model", + "provider": {"name": "JumpStart", "classification": "gated"}, + "artifact_version": "v1", + "hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt", + "s3_data_source": { + "s3_uri": gated_s3_uri, + "compression_type": "None", + "s3_data_type": "S3Prefix", + }, + } + ] + } + ) + return augmented_spec diff --git a/tests/unit/sagemaker/local/test_local_entities.py b/tests/unit/sagemaker/local/test_local_entities.py index 6a026c316b..74a361cf73 100644 --- a/tests/unit/sagemaker/local/test_local_entities.py +++ b/tests/unit/sagemaker/local/test_local_entities.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import re import os import pytest @@ -290,10 +291,10 @@ def test_start_local_pipeline_with_wrong_parameter_type(sagemaker_local_session) local_pipeline = sagemaker.local.entities._LocalPipeline(pipeline) with pytest.raises(ClientError) as error: local_pipeline.start(PipelineParameters={"MyStr": True}) - assert ( - f"Unexpected type for parameter '{parameter.name}'. Expected " - f"{parameter.parameter_type.python_type} but found {type(True)}." in str(error.value) + expected_error_pattern = ( + r"Unexpected type for parameter 'MyStr'\. Expected .* but found .*\."
) + assert re.search(expected_error_pattern, str(error.value)) def test_start_local_pipeline_with_empty_parameter_string_value( diff --git a/tests/unit/sagemaker/local/test_local_image.py b/tests/unit/sagemaker/local/test_local_image.py index 3142fa6dfa..3bb15dc43d 100644 --- a/tests/unit/sagemaker/local/test_local_image.py +++ b/tests/unit/sagemaker/local/test_local_image.py @@ -160,7 +160,7 @@ def test_get_compose_cmd_prefix_with_docker_cli(): "subprocess.check_output", side_effect=subprocess.CalledProcessError(returncode=1, cmd="docker compose version"), ) -@patch("sagemaker.local.image.find_executable", Mock(return_value="/usr/bin/docker-compose")) +@patch("sagemaker.local.image.shutil.which", Mock(return_value="/usr/bin/docker-compose")) def test_get_compose_cmd_prefix_with_docker_compose_cli(check_output): compose_cmd_prefix = _SageMakerContainer._get_compose_cmd_prefix() assert compose_cmd_prefix == ["docker-compose"] @@ -170,7 +170,7 @@ def test_get_compose_cmd_prefix_with_docker_compose_cli(check_output): "subprocess.check_output", side_effect=subprocess.CalledProcessError(returncode=1, cmd="docker compose version"), ) -@patch("sagemaker.local.image.find_executable", Mock(return_value=None)) +@patch("sagemaker.local.image.shutil.which", Mock(return_value=None)) def test_get_compose_cmd_prefix_raises_import_error(check_output): with pytest.raises(ImportError) as e: _SageMakerContainer._get_compose_cmd_prefix() diff --git a/tests/unit/sagemaker/local/test_local_session.py b/tests/unit/sagemaker/local/test_local_session.py index ceae674704..ce8fd19b5c 100644 --- a/tests/unit/sagemaker/local/test_local_session.py +++ b/tests/unit/sagemaker/local/test_local_session.py @@ -47,7 +47,8 @@ @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -142,7 +143,8 @@ def test_create_processing_job(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_not_fully_replicated(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_not_fully_replicated(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -197,7 +199,8 @@ def test_create_processing_job_not_fully_replicated(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_invalid_upload_mode(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_invalid_upload_mode(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -252,7 +255,8 @@ def test_create_processing_job_invalid_upload_mode(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_invalid_processing_input(process, LocalSession): 
+@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_invalid_processing_input(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -302,7 +306,8 @@ def test_create_processing_job_invalid_processing_input(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_invalid_processing_output(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_invalid_processing_output(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -360,7 +365,8 @@ def test_describe_invalid_processing_job(*args): @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") @patch("sagemaker.local.local_session.LocalSession") -def test_create_training_job(train, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_training_job(train, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -427,7 +433,8 @@ def test_describe_invalid_training_job(*args): @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") @patch("sagemaker.local.local_session.LocalSession") -def test_create_training_job_invalid_data_source(train, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_training_job_invalid_data_source(train, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -466,7 +473,8 @@ def test_create_training_job_invalid_data_source(train, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") @patch("sagemaker.local.local_session.LocalSession") -def test_create_training_job_not_fully_replicated(train, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_training_job_not_fully_replicated(train, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -503,7 +511,8 @@ def test_create_training_job_not_fully_replicated(train, LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_create_model(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_model(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER) @@ -512,7 +521,8 @@ def test_create_model(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_delete_model(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_delete_model(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER) @@ -523,7 +533,8 @@ def test_delete_model(LocalSession): 
@patch("sagemaker.local.local_session.LocalSession") -def test_describe_model(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_describe_model(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() with pytest.raises(ClientError): @@ -536,9 +547,10 @@ def test_describe_model(LocalSession): assert response["PrimaryContainer"]["ModelDataUrl"] == "/some/model/path" +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) @patch("sagemaker.local.local_session._LocalTransformJob") @patch("sagemaker.local.local_session.LocalSession") -def test_create_transform_job(LocalSession, _LocalTransformJob): +def test_create_transform_job(LocalSession, _LocalTransformJob, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_transform_job("transform-job", "some-model", None, None, None) @@ -572,7 +584,8 @@ def test_logs_for_processing_job(process, LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_describe_endpoint_config(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_describe_endpoint_config(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() # No Endpoint Config Created @@ -588,7 +601,8 @@ def test_describe_endpoint_config(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_create_endpoint_config(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_endpoint_config(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS) @@ -598,7 +612,8 @@ def test_create_endpoint_config(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_delete_endpoint_config(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_delete_endpoint_config(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS) @@ -613,12 +628,15 @@ def test_delete_endpoint_config(LocalSession): ) +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) @patch("sagemaker.local.image._SageMakerContainer.serve") @patch("sagemaker.local.local_session.LocalSession") @patch("urllib3.PoolManager.request") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_model") -def test_describe_endpoint(describe_model, describe_endpoint_config, request, *args): +def test_describe_endpoint( + describe_model, describe_endpoint_config, request, mock_telemetry, *args +): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() request.return_value = OK_RESPONSE @@ -658,12 +676,13 @@ def test_describe_endpoint(describe_model, describe_endpoint_config, request, *a assert response["EndpointName"] == "test-endpoint" +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) @patch("sagemaker.local.image._SageMakerContainer.serve") 
@patch("sagemaker.local.local_session.LocalSession") @patch("urllib3.PoolManager.request") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_model") -def test_create_endpoint(describe_model, describe_endpoint_config, request, *args): +def test_create_endpoint(describe_model, describe_endpoint_config, request, mock_telemetry, *args): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() request.return_value = OK_RESPONSE diff --git a/tests/unit/sagemaker/local/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py index 42710b2495..82e3207266 100644 --- a/tests/unit/sagemaker/local/test_local_utils.py +++ b/tests/unit/sagemaker/local/test_local_utils.py @@ -85,11 +85,11 @@ def test_move_to_destination_illegal_destination(): @patch("sagemaker.local.utils.os.path") -@patch("sagemaker.local.utils.copy_tree") +@patch("sagemaker.local.utils.shutil.copytree") def test_recursive_copy(copy_tree, m_os_path): m_os_path.isdir.return_value = True sagemaker.local.utils.recursive_copy("source", "destination") - copy_tree.assert_called_with("source", "destination") + copy_tree.assert_called_with("source", "destination", dirs_exist_ok=True) @patch("sagemaker.local.utils.os") @@ -135,6 +135,68 @@ def test_get_docker_host(m_subprocess): assert host == endpoint["result"] +@patch("sagemaker.local.utils.subprocess") +def test_get_docker_host_rootless_docker(m_subprocess): + """Test that rootless Docker is detected and returns fixed IP""" + # Mock docker info process for rootless Docker + info_process_mock = Mock() + info_attrs = {"communicate.return_value": (b"Cgroup Driver: none", b""), "returncode": 0} + info_process_mock.configure_mock(**info_attrs) + m_subprocess.Popen.return_value = info_process_mock + + host = sagemaker.local.utils.get_docker_host() + assert host == "172.17.0.1" + + # Verify docker info was called + m_subprocess.Popen.assert_called_with( + ["docker", "info"], stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE + ) + + +@patch("sagemaker.local.utils.subprocess") +def test_get_docker_host_traditional_docker(m_subprocess): + """Test that traditional Docker falls back to existing logic""" + scenarios = [ + { + "docker_info": b"Cgroup Driver: cgroupfs", + "context_host": "tcp://host:port", + "result": "host", + }, + { + "docker_info": b"Cgroup Driver: cgroupfs", + "context_host": "unix:///var/run/docker.sock", + "result": "localhost", + }, + { + "docker_info": b"Cgroup Driver: cgroupfs", + "context_host": "fd://something", + "result": "localhost", + }, + ] + + for scenario in scenarios: + # Mock docker info process for traditional Docker + info_process_mock = Mock() + info_attrs = {"communicate.return_value": (scenario["docker_info"], b""), "returncode": 0} + info_process_mock.configure_mock(**info_attrs) + + # Mock docker context inspect process + context_return_value = ( + '[\n{\n"Endpoints":{\n"docker":{\n"Host": "%s"}\n}\n}\n]\n' % scenario["context_host"] + ) + context_process_mock = Mock() + context_attrs = { + "communicate.return_value": (context_return_value.encode("utf-8"), None), + "returncode": 0, + } + context_process_mock.configure_mock(**context_attrs) + + m_subprocess.Popen.side_effect = [info_process_mock, context_process_mock] + + host = sagemaker.local.utils.get_docker_host() + assert host == scenario["result"] + + @pytest.mark.parametrize( "json_path, expected", [ diff --git 
a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index 835a09a58c..12d3a2169d 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -59,6 +59,8 @@ def test_jumpstart_default_metric_definitions( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -79,6 +81,8 @@ def test_jumpstart_default_metric_definitions( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/mlflow/__init__.py b/tests/unit/sagemaker/mlflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py b/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py new file mode 100644 index 0000000000..14502880c3 --- /dev/null +++ b/tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py @@ -0,0 +1,279 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
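The get_docker_host tests earlier in this diff (tests/unit/sagemaker/local/test_local_utils.py) describe a two-step probe: `docker info` is checked first, and a report of "Cgroup Driver: none" is taken to mean rootless Docker, which cannot share the host network, so the fixed docker0 gateway address 172.17.0.1 is returned; otherwise the active docker context's endpoint decides the host. A rough sketch under those assumptions, with error handling omitted; this is not the sagemaker.local.utils source.

import json
import subprocess


def get_docker_host_sketch() -> str:
    # Probe `docker info`; rootless Docker reports "Cgroup Driver: none".
    info = subprocess.Popen(["docker", "info"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, _ = info.communicate()
    if "Cgroup Driver: none" in stdout.decode("utf-8"):
        # Rootless Docker cannot reach the host network; fall back to the
        # default docker0 bridge gateway address.
        return "172.17.0.1"
    # Traditional Docker: resolve the active context's endpoint instead.
    ctx = subprocess.Popen(
        ["docker", "context", "inspect"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    stdout, _ = ctx.communicate()
    host_url = json.loads(stdout)[0]["Endpoints"]["docker"]["Host"]
    if host_url.startswith("tcp://"):
        return host_url.split("//")[1].split(":")[0]
    return "localhost"  # unix:// and fd:// sockets serve the local machine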
+ +from __future__ import absolute_import +from unittest.mock import patch, MagicMock, Mock +import json +import pytest +from mlflow.entities import Metric, Param +import requests + + +from sagemaker.mlflow.forward_sagemaker_metrics import ( + encode, + log_sagemaker_job_to_mlflow, + decode, + prepare_mlflow_metrics, + prepare_mlflow_params, + batch_items, + create_metric_queries, + get_metric_data, + log_to_mlflow, + get_training_job_details, +) + + +@pytest.fixture +def mock_boto3_client(): + with patch("boto3.client") as mock_client: + yield mock_client + + +@pytest.fixture +def mock_mlflow_client(): + with patch("mlflow.MlflowClient") as mock_client: + yield mock_client + + +def test_encode(): + existing_names = set() + assert encode("test-name", existing_names) == "test-name" + assert encode("test:name", existing_names) == "test:name" + assert encode("test-name", existing_names) == "test-name_1" + + +def test_encode_colon_allowed(): + # Test case where colon is allowed (Unix-like system and MLflow >= 2.16.0) + with patch("platform.system") as mock_system, patch("mlflow.__version__", new="2.16.0"): + + mock_system.return_value = "Darwin" # MacOS + existing_names = set() + + assert encode("test:name", existing_names) == "test:name" + assert encode("test/name", existing_names) == "test/name" + assert encode("test name", existing_names) == "test name" + assert encode("test@name", existing_names) == "test_40_name" + + # Test name longer than 250 characters + long_name = "a" * 250 + encoded_long_name = encode(long_name, existing_names) + assert len(encoded_long_name) == 250 + assert encoded_long_name == "a" * 250 + + # Test suffix addition for duplicate names + assert encode("duplicate", existing_names) == "duplicate" + assert encode("duplicate", existing_names) == "duplicate_1" + assert encode("duplicate", existing_names) == "duplicate_2" + + +def test_decode(): + assert decode("test_3a_name") == "test:name" + assert decode("normal_name") == "normal_name" + + +def test_get_training_job_details(mock_boto3_client): + mock_sagemaker = MagicMock() + mock_boto3_client.return_value = mock_sagemaker + mock_sagemaker.describe_training_job.return_value = {"JobName": "test-job"} + + result = get_training_job_details( + "arn:aws:sagemaker:us-west-2:123456789012:training-job/test-job" + ) + assert result == {"JobName": "test-job"} + mock_sagemaker.describe_training_job.assert_called_once_with(TrainingJobName="test-job") + + +def test_create_metric_queries(): + job_arn = "arn:aws:sagemaker:us-west-2:123456789012:training-job/test-job" + metric_definitions = [{"Name": "loss"}, {"Name": "accuracy"}] + result = create_metric_queries(job_arn, metric_definitions) + assert len(result) == 2 + assert result[0]["MetricName"] == "loss" + assert result[1]["MetricName"] == "accuracy" + + +def test_get_metric_data(mock_boto3_client): + mock_metrics = MagicMock() + mock_boto3_client.return_value = mock_metrics + mock_metrics.batch_get_metrics.return_value = {"MetricResults": []} + + metric_queries = [{"MetricName": "loss"}] + result = get_metric_data(metric_queries) + assert result == {"MetricResults": []} + mock_metrics.batch_get_metrics.assert_called_once_with(MetricQueries=metric_queries) + + +def test_prepare_mlflow_metrics(): + metric_queries = [{"MetricName": "loss"}, {"MetricName": "accuracy!"}] + metric_results = [ + {"Status": "Complete", "XAxisValues": [1, 2], "MetricValues": [0.1, 0.2]}, + {"Status": "Complete", "XAxisValues": [1, 2], "MetricValues": [0.8, 0.9]}, + ] + expected_encoded = {"loss": "loss", 
"accuracy_21_": "accuracy!"} + + metrics, mapping = prepare_mlflow_metrics(metric_queries, metric_results) + + assert len(metrics) == sum(len(result["MetricValues"]) for result in metric_results) + + expected_metrics = [ + ("loss", 0.1, 1, 0), + ("loss", 0.2, 2, 1), + ("accuracy_21_", 0.8, 1, 0), + ("accuracy_21_", 0.9, 2, 1), + ] + + for metric, (exp_key, exp_value, exp_timestamp, exp_step) in zip(metrics, expected_metrics): + assert metric.key == exp_key + assert metric.value == exp_value + assert metric.timestamp == exp_timestamp + assert metric.step == exp_step + + assert mapping == {v: k for v, k in expected_encoded.items()} + + +def test_prepare_mlflow_params(): + hyperparameters = {"learning_rate": "0.01", "batch_!size": "32"} + expected_encoded = {"learning_rate": "learning_rate", "batch__21_size": "batch_!size"} + + params, mapping = prepare_mlflow_params(hyperparameters) + + assert len(params) == len(hyperparameters) + + for param in params: + assert param.key in expected_encoded + assert param.value == hyperparameters[mapping[param.key]] + + assert mapping == {v: k for v, k in expected_encoded.items()} + + +def test_batch_items(): + items = [1, 2, 3, 4, 5] + batches = list(batch_items(items, 2)) + assert batches == [[1, 2], [3, 4], [5]] + + +@patch("os.getenv") +@patch("requests.Session.request") +def test_log_to_mlflow(mock_request, mock_getenv): + # Set up return values for os.getenv calls + def getenv_side_effect(arg, default=None): + values = { + "MLFLOW_TRACKING_URI": "https://test.sagemaker.aws", + "MLFLOW_REGISTRY_URI": "https://registry.uri", + "MLFLOW_EXPERIMENT_NAME": "test_experiment", + "MLFLOW_ALLOW_HTTP_REDIRECTS": "true", + } + return values.get(arg, default) + + mock_getenv.side_effect = getenv_side_effect + + # Mock the HTTP requests + mock_responses = { + "https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name": Mock( + spec=requests.Response + ), + "https://test.sagemaker.aws/api/2.0/mlflow/runs/create": Mock(spec=requests.Response), + "https://test.sagemaker.aws/api/2.0/mlflow/runs/update": Mock(spec=requests.Response), + "https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch": [ + Mock(spec=requests.Response), + Mock(spec=requests.Response), + Mock(spec=requests.Response), + ], + "https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate": Mock(spec=requests.Response), + } + + mock_responses[ + "https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name" + ].status_code = 200 + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name"].text = ( + json.dumps( + { + "experiment_id": "existing_experiment_id", + "name": "test_experiment", + "artifact_location": "some/path", + "lifecycle_stage": "active", + "tags": {}, + } + ) + ) + + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/create"].status_code = 200 + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/create"].text = json.dumps( + {"run_id": "test_run_id"} + ) + + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/update"].status_code = 200 + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/update"].text = json.dumps( + {"run_id": "test_run_id"} + ) + + for mock_response in mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch"]: + mock_response.status_code = 200 + mock_response.text = json.dumps({}) + + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate"].status_code = 200 + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate"].text = 
json.dumps({}) + + mock_request.side_effect = [ + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/experiments/get-by-name"], + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/create"], + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/update"], + *mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/log-batch"], + mock_responses["https://test.sagemaker.aws/api/2.0/mlflow/runs/terminate"], + ] + + metrics = [Metric("loss", 0.1, 1, 0)] + params = [Param("learning_rate", "0.01")] + tags = {"tag1": "value1"} + + log_to_mlflow(metrics, params, tags) + + assert mock_request.call_count == 7 # Total number of API calls + + +@patch("sagemaker.mlflow.forward_sagemaker_metrics.get_training_job_details") +@patch("sagemaker.mlflow.forward_sagemaker_metrics.create_metric_queries") +@patch("sagemaker.mlflow.forward_sagemaker_metrics.get_metric_data") +@patch("sagemaker.mlflow.forward_sagemaker_metrics.prepare_mlflow_metrics") +@patch("sagemaker.mlflow.forward_sagemaker_metrics.prepare_mlflow_params") +@patch("sagemaker.mlflow.forward_sagemaker_metrics.log_to_mlflow") +def test_log_sagemaker_job_to_mlflow( + mock_log_to_mlflow, + mock_prepare_params, + mock_prepare_metrics, + mock_get_metric_data, + mock_create_queries, + mock_get_job_details, +): + mock_get_job_details.return_value = { + "HyperParameters": {"learning_rate": "0.01"}, + "AlgorithmSpecification": {"MetricDefinitions": [{"Name": "loss"}]}, + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:123456789012:training-job/test-job", + } + mock_create_queries.return_value = [{"MetricName": "loss"}] + mock_get_metric_data.return_value = {"MetricQueryResults": []} + mock_prepare_metrics.return_value = ([], {}) + mock_prepare_params.return_value = ([], {}) + + log_sagemaker_job_to_mlflow("test-job") + + mock_get_job_details.assert_called_once() + mock_create_queries.assert_called_once() + mock_get_metric_data.assert_called_once() + mock_prepare_metrics.assert_called_once() + mock_prepare_params.assert_called_once() + mock_log_to_mlflow.assert_called_once() + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/unit/sagemaker/mlflow/test_tracking_server.py b/tests/unit/sagemaker/mlflow/test_tracking_server.py new file mode 100644 index 0000000000..1fc4943f16 --- /dev/null +++ b/tests/unit/sagemaker/mlflow/test_tracking_server.py @@ -0,0 +1,42 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
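The encode/decode tests above imply a reversible escaping scheme: characters MLflow rejects become "_<two-hex-digit ordinal>_" (so "!" encodes to "_21_", "@" to "_40_", and "_3a_" decodes back to ":"), names are capped at 250 characters, and duplicate names get numeric suffixes. A hedged sketch follows; the allowed-character set is an assumption, and the platform/MLflow-version gating of ":" seen in test_encode_colon_allowed is deliberately skipped here.

import re


def encode_sketch(name: str, existing_names: set) -> str:
    # Escape anything outside the assumed allowed set as _<hex ordinal>_,
    # e.g. "!" -> "_21_", "@" -> "_40_"; then cap the length at 250.
    encoded = "".join(
        ch if re.match(r"[A-Za-z0-9_\-. /:]", ch) else f"_{ord(ch):x}_" for ch in name
    )[:250]
    if encoded in existing_names:
        suffix = 1
        while f"{encoded}_{suffix}" in existing_names:
            suffix += 1
        encoded = f"{encoded}_{suffix}"
    existing_names.add(encoded)
    return encoded


def decode_sketch(name: str) -> str:
    # Reverse the escaping: "_3a_" -> ":" (assumes two-hex-digit escapes).
    return re.sub(r"_([0-9a-f]{2})_", lambda match: chr(int(match.group(1), 16)), name)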
+ +from __future__ import absolute_import +from sagemaker.mlflow.tracking_server import generate_mlflow_presigned_url + + +def test_generate_presigned_url(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_presigned_mlflow_tracking_server_url.return_value = { + "AuthorizedUrl": "https://t-wo.example.com", + } + url = generate_mlflow_presigned_url( + "w", + expires_in_seconds=10, + session_expiration_duration_in_seconds=5, + sagemaker_session=sagemaker_session, + ) + client.create_presigned_mlflow_tracking_server_url.assert_called_with( + TrackingServerName="w", ExpiresInSeconds=10, SessionExpirationDurationInSeconds=5 + ) + assert url == "https://t-wo.example.com" + + +def test_generate_presigned_url_minimal(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_presigned_mlflow_tracking_server_url.return_value = { + "AuthorizedUrl": "https://t-wo.example.com", + } + url = generate_mlflow_presigned_url("w", sagemaker_session=sagemaker_session) + client.create_presigned_mlflow_tracking_server_url.assert_called_with(TrackingServerName="w") + assert url == "https://t-wo.example.com" diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 953cbe775c..ef8b5e6af5 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -23,6 +23,7 @@ from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.explainer import ExplainerConfig from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.enums import EndpointType from tests.unit.sagemaker.inference_recommender.constants import ( DESCRIBE_COMPILATION_JOB_RESPONSE, DESCRIBE_MODEL_PACKAGE_RESPONSE, @@ -114,7 +115,11 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, @@ -125,6 +130,8 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, + inference_ami_version=None, ) sagemaker_session.create_model.assert_called_with( @@ -174,6 +181,8 @@ def test_deploy_accelerator_type( accelerator_type=ACCELERATOR_TYPE, tags=None, serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, @@ -184,6 +193,8 @@ def test_deploy_accelerator_type( volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, + inference_ami_version=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -293,6 +304,8 @@ def test_deploy_tags(create_sagemaker_model, production_variant, name_from_base, accelerator_type=None, tags=tags, serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( name=ENDPOINT_NAME, @@ -496,6 +509,8 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model, accelerator_type=None, tags=None, serverless_inference_config=serverless_inference_config, 
+ accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, @@ -506,6 +521,8 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model, volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, + inference_ami_version=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -927,7 +944,11 @@ def test_deploy_customized_volume_size_and_timeout( assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, @@ -938,6 +959,8 @@ def test_deploy_customized_volume_size_and_timeout( volume_size=volume_size_gb, model_data_download_timeout=model_data_download_timeout_sec, container_startup_health_check_timeout=startup_health_check_timeout_sec, + routing_config=None, + inference_ami_version=None, ) sagemaker_session.create_model.assert_called_with( @@ -987,6 +1010,8 @@ def test_deploy_with_resources(sagemaker_session, name_from_base, production_var volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + routing_config=None, + inference_ami_version=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( name=name_from_base(MODEL_NAME), @@ -1027,3 +1052,143 @@ def test_deploy_with_name_and_resources(sagemaker_session): async_inference_config_dict=None, live_logging=False, ) + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_with_update_endpoint(production_variant, name_from_base, sagemaker_session): + model = Model( + MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session + ) + + # Mock the create_endpoint_config to return a specific config name + endpoint_config_name = "test-config-name" + sagemaker_session.create_endpoint_config.return_value = endpoint_config_name + + # Test update_endpoint=True scenario + endpoint_name = "existing-endpoint" + model.deploy( + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + endpoint_name=endpoint_name, + update_endpoint=True, + ) + + # Verify create_endpoint_config is called with correct parameters + sagemaker_session.create_endpoint_config.assert_called_with( + name=MODEL_NAME, + model_name=MODEL_NAME, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + explainer_config_dict=None, + async_inference_config_dict=None, + serverless_inference_config_dict=None, + routing_config=None, + inference_ami_version=None, + ) + + # Verify update_endpoint is called with correct parameters + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, endpoint_config_name) + + # Test update_endpoint with serverless config + serverless_inference_config = ServerlessInferenceConfig() + serverless_inference_config_dict = { + "MemorySizeInMB": 2048, + "MaxConcurrency": 5, + } + model.deploy( + endpoint_name=endpoint_name, + 
update_endpoint=True, + serverless_inference_config=serverless_inference_config, + ) + + sagemaker_session.create_endpoint_config.assert_called_with( + name=MODEL_NAME, + model_name=MODEL_NAME, + initial_instance_count=None, + instance_type=None, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + explainer_config_dict=None, + async_inference_config_dict=None, + serverless_inference_config_dict=serverless_inference_config_dict, + routing_config=None, + inference_ami_version=None, + ) + + # Verify update_endpoint is called with the new config + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, endpoint_config_name) + + # Test update_endpoint with async inference config + async_inference_config = AsyncInferenceConfig( + output_path="s3://bucket/output", failure_path="s3://bucket/failure" + ) + async_inference_config_dict = { + "OutputConfig": { + "S3OutputPath": "s3://bucket/output", + "S3FailurePath": "s3://bucket/failure", + }, + } + model.deploy( + endpoint_name=endpoint_name, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + update_endpoint=True, + async_inference_config=async_inference_config, + ) + + sagemaker_session.create_endpoint_config.assert_called_with( + name=MODEL_NAME, + model_name=MODEL_NAME, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + explainer_config_dict=None, + async_inference_config_dict=async_inference_config_dict, + serverless_inference_config_dict=None, + routing_config=None, + inference_ami_version=None, + ) + + # Verify update_endpoint is called with the new config + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, endpoint_config_name) + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_with_update_endpoint_inference_component(production_variant, sagemaker_session): + model = Model( + MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session + ) + + # Test that updating endpoint with inference component raises error + with pytest.raises( + ValueError, match="Currently update_endpoint is supported for single model endpoints" + ): + model.deploy( + endpoint_name="test-endpoint", + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + update_endpoint=True, + resources=RESOURCES, + endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, + ) diff --git a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index d41dd6f821..432d90bd37 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -511,6 +511,20 @@ def test_is_repack_with_code_location(repack_model, sagemaker_session): assert not model.is_repack() +@patch("sagemaker.utils.repack_model") +def test_is_repack_with_none_type(repack_model, sagemaker_session): + """Test is_repack() returns a boolean value when source_dir and entry_point are None""" + + model = FrameworkModel( + role=ROLE, + sagemaker_session=sagemaker_session, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + + assert model.is_repack() is False + + 
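# test_is_repack_with_none_type above pins a subtle contract: when neither
# source_dir nor entry_point is set, is_repack() must still return an actual
# boolean (False), not a falsy None that an `is False` check would reject. A
# defensive pattern matching that assertion (illustrative only; the real
# method also weighs model_data and other state):
def _is_repack_sketch(source_dir, entry_point):
    # bool() collapses None/"" to False so `is_repack() is False` holds
    return bool(source_dir and entry_point)

assert _is_repack_sketch(None, None) is False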
@patch("sagemaker.git_utils.git_clone_repo") @patch("sagemaker.model.fw_utils.tar_and_upload_dir") def test_is_repack_with_git_config(tar_and_upload_dir, git_clone_repo, sagemaker_session): diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index c0b18a3eb3..3d498dfc59 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -287,7 +287,11 @@ def test_create_sagemaker_model(prepare_container_def, sagemaker_session): model._create_sagemaker_model() prepare_container_def.assert_called_with( - None, accelerator_type=None, serverless_inference_config=None, accept_eula=None + None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) sagemaker_session.create_model.assert_called_with( name=MODEL_NAME, @@ -305,7 +309,11 @@ def test_create_sagemaker_model_instance_type(prepare_container_def, sagemaker_s model._create_sagemaker_model(INSTANCE_TYPE) prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) @@ -321,6 +329,7 @@ def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemake accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ) @@ -336,6 +345,7 @@ def test_create_sagemaker_model_with_eula(prepare_container_def, sagemaker_sessi accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=True, + model_reference_arn=None, ) @@ -351,6 +361,7 @@ def test_create_sagemaker_model_with_eula_false(prepare_container_def, sagemaker accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=False, + model_reference_arn=None, ) @@ -948,6 +959,56 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path( sagemaker_session.create_model.reset_mock() +@patch("sagemaker.utils.repack_model") +@patch("sagemaker.fw_utils.tar_and_upload_dir") +def test_sharded_model_force_inference_component_based_endpoint_deploy_path( + repack_model, tar_and_uload_dir, sagemaker_session +): + framework_model_classes_to_kwargs = { + HuggingFaceModel: { + "pytorch_version": "1.7.1", + "py_version": "py36", + "transformers_version": "4.6.1", + }, + } + + sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False) + + source_dir = "s3://blah/blah/blah" + for framework_model_class, kwargs in framework_model_classes_to_kwargs.items(): + test_sharded_model = framework_model_class( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + model_data=source_dir, + **kwargs, + ) + test_sharded_model._is_sharded_model = True + test_sharded_model.deploy( + instance_type="ml.m2.xlarge", + initial_instance_count=INSTANCE_COUNT, + endpoint_type=EndpointType.MODEL_BASED, + resources=ResourceRequirements( + requests={ + "num_accelerators": 1, + "memory": 8192, + "copies": 1, + }, + limits={}, + ), + ) + + # Verified inference component based endpoint and inference component creation + # path + sagemaker_session.endpoint_in_service_or_not.assert_called_once() + sagemaker_session.create_model.assert_called_once() + sagemaker_session.create_inference_component.assert_called_once() + + sagemaker_session.create_inference_component.reset_mock() + sagemaker_session.endpoint_in_service_or_not.reset_mock() + 
sagemaker_session.create_model.reset_mock() + + @patch("sagemaker.utils.repack_model") def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session): @@ -985,6 +1046,20 @@ def test_is_repack_with_code_location(repack_model, sagemaker_session): assert model.is_repack() +@patch("sagemaker.utils.repack_model") +def test_is_repack_with_none_type(repack_model, sagemaker_session): + """Test is_repack() returns a boolean value when source_dir and entry_point are None""" + + model = Model( + role=ROLE, + sagemaker_session=sagemaker_session, + image_uri=IMAGE_URI, + model_data=MODEL_DATA, + ) + + assert model.is_repack() is False + + @patch("sagemaker.git_utils.git_clone_repo") @patch("sagemaker.model.fw_utils.tar_and_upload_dir") def test_is_repack_with_git_config(tar_and_upload_dir, git_clone_repo, sagemaker_session): @@ -1421,3 +1496,47 @@ def test_model_source( ) assert model_1._get_model_uri() == "s3://tmybuckaet" + + +@patch("sagemaker.utils.repack_model") +@patch("sagemaker.fw_utils.tar_and_upload_dir") +def test_deploy_sharded_model_with_cpus_requested_raises_warning( + repack_model, tar_and_upload_dir, sagemaker_session +): + framework_model_classes_to_kwargs = { + HuggingFaceModel: { + "pytorch_version": "1.7.1", + "py_version": "py36", + "transformers_version": "4.6.1", + }, + } + + sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False) + + source_dir = "s3://blah/blah/blah" + for framework_model_class, kwargs in framework_model_classes_to_kwargs.items(): + test_sharded_model = framework_model_class( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + model_data=source_dir, + **kwargs, + ) + test_sharded_model._is_sharded_model = True + from unittest import mock + + with mock.patch("sagemaker.model.logger") as mock_logger: + mock_logger.warning.reset_mock() + test_sharded_model.deploy( + instance_type="ml.m2.xlarge", + initial_instance_count=INSTANCE_COUNT, + endpoint_type=EndpointType.MODEL_BASED, + resources=ResourceRequirements( + requests={"num_accelerators": 1, "memory": 8192, "copies": 1, "num_cpus": 1}, + limits={}, + ), + ) + mock_logger.warning.assert_called_once_with( + "NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker " + "Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`." 
+ ) diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index 9bfc830a75..85649a8d24 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -19,7 +19,9 @@ import sagemaker from sagemaker.model import ModelPackage -from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum +from sagemaker.model_card.model_card import ModelCard, ModelOverview +from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum, ModelCardStatusEnum +from sagemaker.model_life_cycle import ModelLifeCycle MODEL_PACKAGE_VERSIONED_ARN = ( "arn:aws:sagemaker:us-west-2:001234567890:model-package/testmodelgroup/1" @@ -56,6 +58,10 @@ "ModelPackageStatus": "Completed", "ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502", "CertifyForMarketplace": False, + "ModelCard": { + "ModelCardStatus": "Draft", + "ModelCardContent": '{"model_overview": {"model_creator": "updatedCreator", "model_artifact": []}}', + }, } MODEL_DATA = { @@ -442,3 +448,85 @@ def test_update_source_uri(sagemaker_session): sagemaker_session.sagemaker_client.update_model_package.assert_called_with( ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, SourceUri=source_uri ) + + +def test_update_model_card(sagemaker_session): + model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) + + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=model_package_response + ) + model_package = ModelPackage( + role="role", + model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, + sagemaker_session=sagemaker_session, + ) + + update_my_card = ModelCard( + name="UpdateTestName", + sagemaker_session=sagemaker_session, + status=ModelCardStatusEnum.PENDING_REVIEW, + ) + model_package.update_model_card(update_my_card) + update_my_card_req = update_my_card._create_request_args() + del update_my_card_req["ModelCardName"] + del update_my_card_req["Content"] + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelCard=update_my_card_req + ) + + model_overview = ModelOverview( + model_creator="UpdatedNewCreator", + ) + update_my_card_1 = ModelCard( + name="UpdateTestName", + sagemaker_session=sagemaker_session, + status=ModelCardStatusEnum.DRAFT, + model_overview=model_overview, + ) + model_package.update_model_card(update_my_card_1) + update_my_card_req_1 = update_my_card_1._create_request_args() + del update_my_card_req_1["ModelCardName"] + del update_my_card_req_1["ModelCardStatus"] + update_my_card_req_1["ModelCardContent"] = update_my_card_req_1["Content"] + del update_my_card_req_1["Content"] + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelCard=update_my_card_req_1 + ) + + +def test_update_model_life_cycle(sagemaker_session): + model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) + + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=model_package_response + ) + model_package = ModelPackage( + role="role", + model_package_arn=MODEL_PACKAGE_VERSIONED_ARN, + sagemaker_session=sagemaker_session, + ) + + update_model_life_cycle = ModelLifeCycle( + stage="Development", + stage_status="Approved", + stage_description="Approving for Development", + ) + update_model_life_cycle_req = update_model_life_cycle._to_request_dict() + 
model_package.update_model_life_cycle(update_model_life_cycle_req) + + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelLifeCycle=update_model_life_cycle_req + ) + + update_model_life_cycle1 = ModelLifeCycle( + stage="Staging", + stage_status="In-Progress", + stage_description="Sending for Staging Verification", + ) + update_model_life_cycle_req1 = update_model_life_cycle1._to_request_dict() + model_package.update_model_life_cycle(update_model_life_cycle_req1) + + sagemaker_session.sagemaker_client.update_model_package.assert_called_with( + ModelPackageArn=MODEL_PACKAGE_VERSIONED_ARN, ModelLifeCycle=update_model_life_cycle_req1 + ) diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 8ec9478d8a..8593c599fe 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -54,6 +54,8 @@ def test_jumpstart_common_model_uri( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -72,6 +74,8 @@ def test_jumpstart_common_model_uri( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -91,6 +95,8 @@ def test_jumpstart_common_model_uri( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -110,6 +116,8 @@ def test_jumpstart_common_model_uri( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -176,4 +184,7 @@ def test_jumpstart_artifact_bucket_override( model_id="pytorch-ic-mobilenet-v2", model_version="*", ) - assert uri == "s3://some-cool-bucket-name/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" + assert ( + uri + == "s3://some-cool-bucket-name/pytorch-training/v2.0.0/train-pytorch-ic-mobilenet-v2.tar.gz" + ) diff --git a/tests/unit/sagemaker/modules/__init__.py b/tests/unit/sagemaker/modules/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/modules/local_core/test_local_container.py b/tests/unit/sagemaker/modules/local_core/test_local_container.py new file mode 100644 index 0000000000..88f6f81707 --- /dev/null +++ b/tests/unit/sagemaker/modules/local_core/test_local_container.py @@ -0,0 +1,179 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
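# The fixtures and helpers below encode the docker-compose layout that
# _LocalContainer is expected to emit: every service mounts the shared model
# and input-channel volumes, plus host-specific /opt/ml/output and
# /opt/ml/input directories, and joins the sagemaker-local network under its
# host alias. A sketch of the per-host volume mapping those expectations
# imply (assumed layout, mirroring expected_host_config further down):
def _host_volumes_sketch(container_root, shared_volumes, host):
    return shared_volumes + [
        f"{container_root}/{host}/output:/opt/ml/output",
        f"{container_root}/{host}/output/data:/opt/ml/output/data",
        f"{container_root}/{host}/input:/opt/ml/input",
    ]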
+"""LocalContainer Tests.""" +from __future__ import absolute_import +import os +import shutil + +import pytest + +from sagemaker.modules.configs import Channel, FileSystemDataSource +from sagemaker.modules.local_core.local_container import DOCKER_COMPOSE_FILENAME, _LocalContainer +from sagemaker_core.shapes import DataSource + +TRAINING_JOB_NAME = "job_name" +INSTANCE_TYPE = "ml.m5.xlarge" +TEST_IMAGE_NAME = "test_image" +CONTAINER_ROOT = os.getcwd() +CONTAINER_ENTRYPOINT = ["/bin/bash"] +CONTAINER_ARGUMENTS = [ + "-c", + ( + "chmod +x /opt/ml/input/data/sm_drivers/sm_train.sh " + + "&& /opt/ml/input/data/sm_drivers/sm_train.sh" + ), +] + + +@pytest.fixture +def input_data_config(): + return [ + Channel( + channel_name="local_input_channel", + data_source=DataSource( + file_system_data_source=FileSystemDataSource.model_construct( + directory_path=CONTAINER_ROOT, + file_system_type="EFS", + ), + ), + input_mode="File", + ) + ] + + +@pytest.fixture +def hyper_parameters(): + return { + "epochs": "1", + "optimizer": "adamw_torch", + } + + +@pytest.fixture +def shared_volumes(): + return [ + f"{CONTAINER_ROOT}/model:/opt/ml/model", + f"{CONTAINER_ROOT}:/opt/ml/input/data/local_input_channel", + ] + + +@pytest.fixture +def environment(): + return { + "SM_OUTPUT_DIR": "/opt/ml/output", + "SM_INPUT_CONFIG_DIR": "/opt/ml/input/config", + "SM_OUTPUT_DATA_DIR": "/opt/ml/output/data", + } + + +@pytest.fixture +def local_container(input_data_config, hyper_parameters, environment): + container = _LocalContainer( + training_job_name=TRAINING_JOB_NAME, + instance_type=INSTANCE_TYPE, + instance_count=2, + image=TEST_IMAGE_NAME, + container_root=CONTAINER_ROOT, + is_studio=False, + input_data_config=input_data_config, + hyper_parameters=hyper_parameters, + environment=environment, + sagemaker_session=None, + container_entrypoint=CONTAINER_ENTRYPOINT, + container_arguments=CONTAINER_ARGUMENTS, + ) + return container + + +def expected_host_config(shared_volumes, host): + return { + "entrypoint": [ + "/bin/bash", + "-c", + "chmod +x /opt/ml/input/data/sm_drivers/sm_train.sh && " + "/opt/ml/input/data/sm_drivers/sm_train.sh", + ], + "environment": [ + "SM_OUTPUT_DIR=/opt/ml/output", + "SM_INPUT_CONFIG_DIR=/opt/ml/input/config", + "SM_OUTPUT_DATA_DIR=/opt/ml/output/data", + ], + "image": "test_image", + "networks": { + "sagemaker-local": { + "aliases": [ + host, + ], + }, + }, + "volumes": shared_volumes + + [ + f"{CONTAINER_ROOT}/{host}/output:/opt/ml/output", + f"{CONTAINER_ROOT}/{host}/output/data:/opt/ml/output/data", + f"{CONTAINER_ROOT}/{host}/input:/opt/ml/input", + ], + } + + +def expected_compose_file(shared_volumes, hosts): + return { + "networks": { + "sagemaker-local": { + "name": "sagemaker-local", + }, + }, + "services": {host: expected_host_config(shared_volumes, host) for host in hosts}, + } + + +def test_write_config_files(local_container, input_data_config, hyper_parameters): + config_path = os.path.join(local_container.container_root, "algo-1", "input", "config") + os.makedirs(config_path, exist_ok=True) + local_container._write_config_files( + host="algo-1", + input_data_config=input_data_config, + hyper_parameters=hyper_parameters, + ) + + assert os.path.exists(os.path.join(config_path, "hyperparameters.json")) + assert os.path.exists(os.path.join(config_path, "resourceconfig.json")) + assert os.path.exists(os.path.join(config_path, "inputdataconfig.json")) + + shutil.rmtree(config_path) + + +def test_prepare_training_volumes( + local_container, input_data_config, hyper_parameters, 
shared_volumes +): + data_dir = os.path.join(local_container.container_root, "input", "data") + output = local_container._prepare_training_volumes( + data_dir, input_data_config, hyper_parameters + ) + + assert output == shared_volumes + + +def test_create_docker_host(local_container, environment, shared_volumes): + host = "algo-1" + output = local_container._create_docker_host(host, environment, shared_volumes) + assert output == expected_host_config(shared_volumes, host) + + +def test_generate_compose_file(local_container, environment, shared_volumes): + output = local_container._generate_compose_file(environment, shared_volumes) + + assert output == expected_compose_file(shared_volumes, local_container.hosts) + + docker_compose_path = os.path.join(local_container.container_root, DOCKER_COMPOSE_FILENAME) + assert os.path.exists(docker_compose_path) + os.remove(docker_compose_path) diff --git a/tests/unit/sagemaker/modules/test_utils.py b/tests/unit/sagemaker/modules/test_utils.py new file mode 100644 index 0000000000..efe43f1792 --- /dev/null +++ b/tests/unit/sagemaker/modules/test_utils.py @@ -0,0 +1,140 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Utils Tests.""" +from __future__ import absolute_import + +import pytest + +from tests.unit import DATA_DIR +from sagemaker.modules.utils import ( + _is_valid_s3_uri, + _is_valid_path, + _get_unique_name, + _get_repo_name_from_image, +) + + +@pytest.mark.parametrize( + "test_case", + [ + { + "path": "s3://bucket/key", + "path_type": "Any", + "expected": True, + }, + { + "path": "s3://bucket/key", + "path_type": "File", + "expected": True, + }, + { + "path": "s3://bucket/key/", + "path_type": "Directory", + "expected": True, + }, + { + "path": "s3://bucket/key/", + "path_type": "File", + "expected": False, + }, + { + "path": "s3://bucket/key", + "path_type": "Directory", + "expected": False, + }, + { + "path": "/bucket/key", + "path_type": "Any", + "expected": False, + }, + ], +) +def test_is_valid_s3_uri(test_case): + assert _is_valid_s3_uri(test_case["path"], test_case["path_type"]) == test_case["expected"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "path": DATA_DIR, + "path_type": "Any", + "expected": True, + }, + { + "path": DATA_DIR, + "path_type": "Directory", + "expected": True, + }, + { + "path": f"{DATA_DIR}/dummy_input.txt", + "path_type": "File", + "expected": True, + }, + { + "path": f"{DATA_DIR}/dummy_input.txt", + "path_type": "Directory", + "expected": False, + }, + { + "path": f"{DATA_DIR}/non_existent", + "path_type": "Any", + "expected": False, + }, + ], +) +def test_is_valid_path(test_case): + assert _is_valid_path(test_case["path"], test_case["path_type"]) == test_case["expected"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "base": "test", + "max_length": 5, + }, + { + "base": "1111111111" * 7, + "max_length": None, + }, + ], +) +def test_get_unique_name(test_case): + assert ( + len(_get_unique_name(test_case["base"], test_case.get("max_length"))) + <= test_case["max_length"] + 
if test_case.get("max_length") + else 63 + ) + + +@pytest.mark.parametrize( + "test_case", + [ + { + "image": "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:latest", + "expected": "my-custom-image", + }, + { + "image": "my-custom-image:latest", + "expected": "my-custom-image", + }, + { + "image": "public.ecr.aws/docker/library/my-custom-image:latest", + "expected": "my-custom-image", + }, + ], +) +def test_get_repo_name_from_image(test_case): + assert _get_repo_name_from_image(test_case["image"]) == test_case["expected"] diff --git a/tests/unit/sagemaker/modules/train/__init__.py b/tests/unit/sagemaker/modules/train/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/modules/train/container_drivers/__init__.py b/tests/unit/sagemaker/modules/train/container_drivers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py new file mode 100644 index 0000000000..fe4fa08825 --- /dev/null +++ b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py @@ -0,0 +1,213 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
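# EXPECTED_ENV below is a literal shell snippet of export KEY='value' lines,
# with non-string values JSON-serialized before being single-quoted. A sketch
# of the write path those assertions imply (hypothetical helper; the real
# logic lives in the environment script under test):
import json

def _write_env_file_sketch(env_vars, output_file):
    with open(output_file, "w") as f:
        for key, value in env_vars.items():
            rendered = value if isinstance(value, str) else json.dumps(value)
            f.write(f"export {key}='{rendered}'\n")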
+"""Enviornment Variable Script Unit Tests.""" +from __future__ import absolute_import + +import os +import io +import logging + +from unittest.mock import patch + +from sagemaker.modules.train.container_drivers.scripts.environment import ( + set_env, + log_env_variables, + HIDDEN_VALUE, +) +from sagemaker.modules.train.container_drivers.common.utils import safe_serialize, safe_deserialize + +RESOURCE_CONFIG = dict( + current_host="algo-1", + hosts=["algo-1", "algo-2", "algo-3"], + current_group_name="train1", + current_instance_type="ml.p3.16xlarge", + instance_groups=[ + dict( + instance_group_name="train1", + instance_type="ml.p3.16xlarge", + hosts=["algo-1", "algo-2"], + ), + dict( + instance_group_name="train2", + instance_type="ml.p3.8xlarge", + hosts=["algo-3"], + ), + ], + network_interface_name="eth0", +) + +INPUT_DATA_CONFIG = { + "train": { + "ContentType": "trainingContentType", + "TrainingInputMode": "File", + "S3DistributionType": "FullyReplicated", + "RecordWrapperType": "None", + }, + "validation": { + "TrainingInputMode": "File", + "S3DistributionType": "FullyReplicated", + "RecordWrapperType": "None", + }, +} + +USER_HYPERPARAMETERS = { + "batch_size": 32, + "learning_rate": 0.001, + "hosts": ["algo-1", "algo-2"], + "mp_parameters": { + "microbatches": 2, + "partitions": 2, + "pipeline": "interleaved", + "optimize": "memory", + "horovod": True, + }, +} + +SOURCE_CODE = { + "source_dir": "code", + "entry_script": "train.py", +} + +DISTRIBUTED_CONFIG = { + "process_count_per_node": 2, +} + +OUTPUT_FILE = os.path.join(os.path.dirname(__file__), "sm_training.env") + +# flake8: noqa +EXPECTED_ENV = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_LOG_LEVEL='20' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_SOURCE_DIR='/opt/ml/input/data/code' +export SM_ENTRY_SCRIPT='train.py' +export SM_DISTRIBUTED_DRIVER_DIR='/opt/ml/input/data/sm_drivers/distributed_drivers' +export SM_DISTRIBUTED_CONFIG='{"process_count_per_node": 2}' +export SM_CHANNEL_TRAIN='/opt/ml/input/data/train' +export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation' +export SM_CHANNELS='["train", "validation"]' +export SM_HP_BATCH_SIZE='32' +export SM_HP_LEARNING_RATE='0.001' +export SM_HP_HOSTS='["algo-1", "algo-2"]' +export SM_HP_MP_PARAMETERS='{"microbatches": 2, "partitions": 2, "pipeline": "interleaved", "optimize": "memory", "horovod": true}' +export SM_HPS='{"batch_size": 32, "learning_rate": 0.001, "hosts": ["algo-1", "algo-2"], "mp_parameters": {"microbatches": 2, "partitions": 2, "pipeline": "interleaved", "optimize": "memory", "horovod": true}}' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.p3.16xlarge' +export SM_HOSTS='["algo-1", "algo-2", "algo-3"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='3' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='8' +export SM_NUM_GPUS='0' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3"], "current_group_name": "train1", "current_instance_type": "ml.p3.16xlarge", "instance_groups": [{"instance_group_name": "train1", "instance_type": "ml.p3.16xlarge", "hosts": ["algo-1", "algo-2"]}, {"instance_group_name": "train2", "instance_type": "ml.p3.8xlarge", 
"hosts": ["algo-3"]}], "network_interface_name": "eth0"}' +export SM_INPUT_DATA_CONFIG='{"train": {"ContentType": "trainingContentType", "TrainingInputMode": "File", "S3DistributionType": "FullyReplicated", "RecordWrapperType": "None"}, "validation": {"TrainingInputMode": "File", "S3DistributionType": "FullyReplicated", "RecordWrapperType": "None"}}' +export SM_TRAINING_ENV='{"channel_input_dirs": {"train": "/opt/ml/input/data/train", "validation": "/opt/ml/input/data/validation"}, "current_host": "algo-1", "current_instance_type": "ml.p3.16xlarge", "hosts": ["algo-1", "algo-2", "algo-3"], "master_addr": "algo-1", "master_port": 7777, "hyperparameters": {"batch_size": 32, "learning_rate": 0.001, "hosts": ["algo-1", "algo-2"], "mp_parameters": {"microbatches": 2, "partitions": 2, "pipeline": "interleaved", "optimize": "memory", "horovod": true}}, "input_data_config": {"train": {"ContentType": "trainingContentType", "TrainingInputMode": "File", "S3DistributionType": "FullyReplicated", "RecordWrapperType": "None"}, "validation": {"TrainingInputMode": "File", "S3DistributionType": "FullyReplicated", "RecordWrapperType": "None"}}, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "log_level": 20, "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 8, "num_gpus": 0, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3"], "current_group_name": "train1", "current_instance_type": "ml.p3.16xlarge", "instance_groups": [{"instance_group_name": "train1", "instance_type": "ml.p3.16xlarge", "hosts": ["algo-1", "algo-2"]}, {"instance_group_name": "train2", "instance_type": "ml.p3.8xlarge", "hosts": ["algo-3"]}], "network_interface_name": "eth0"}}' +""" + + +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.read_source_code_json", + return_value=SOURCE_CODE, +) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.read_distributed_json", + return_value=DISTRIBUTED_CONFIG, +) +@patch("sagemaker.modules.train.container_drivers.scripts.environment.num_cpus", return_value=8) +@patch("sagemaker.modules.train.container_drivers.scripts.environment.num_gpus", return_value=0) +@patch("sagemaker.modules.train.container_drivers.scripts.environment.num_neurons", return_value=0) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.safe_serialize", + side_effect=safe_serialize, +) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.safe_deserialize", + side_effect=safe_deserialize, +) +def test_set_env( + mock_safe_deserialize, + mock_safe_serialize, + mock_num_neurons, + mock_num_gpus, + mock_num_cpus, + mock_read_distributed_json, + mock_read_source_code_json, +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + set_env( + resource_config=RESOURCE_CONFIG, + input_data_config=INPUT_DATA_CONFIG, + hyperparameters_config=USER_HYPERPARAMETERS, + output_file=OUTPUT_FILE, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + mock_read_distributed_json.assert_called_once() + mock_read_source_code_json.assert_called_once() + + with open(OUTPUT_FILE, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + os.remove(OUTPUT_FILE) + assert not 
os.path.exists(OUTPUT_FILE) + + +@patch.dict(os.environ, {"SECRET_TOKEN": "122345678", "CLEAR_DATA": "123456789"}, clear=True) +def test_log_env_variables(): + log_stream = io.StringIO() + handler = logging.StreamHandler(log_stream) + + logger = logging.getLogger("sagemaker.modules.train.container_drivers.scripts.environment") + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + env_vars = { + "SM_MODEL_DIR": "/opt/ml/model", + "SM_INPUT_DIR": "/opt/ml/input", + "SM_HPS": {"batch_size": 32, "learning_rate": 0.001, "access_token": "123456789"}, + "SM_HP_BATCH_SIZE": 32, + "SM_HP_LEARNING_RATE": 0.001, + "SM_HP_ACCESS_TOKEN": "123456789", + } + log_env_variables(env_vars_dict=env_vars) + + log_output = log_stream.getvalue() + + assert f"SECRET_TOKEN={HIDDEN_VALUE}" in log_output + assert "CLEAR_DATA=123456789" in log_output + assert "SM_MODEL_DIR=/opt/ml/model" in log_output + assert ( + f'SM_HPS={{"batch_size": 32, "learning_rate": 0.001, "access_token": "{HIDDEN_VALUE}"}}' + in log_output + ) + assert "SM_HP_BATCH_SIZE=32" in log_output + assert "SM_HP_LEARNING_RATE=0.001" in log_output + assert f"SM_HP_ACCESS_TOKEN={HIDDEN_VALUE}" in log_output + + +def _remove_extra_lines(string): + """Removes extra blank lines from a string.""" + return "\n".join([line for line in string.splitlines() if line.strip()]) diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py new file mode 100644 index 0000000000..bf51db8285 --- /dev/null +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py @@ -0,0 +1,163 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
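# The two tests below split mpi_driver.main() into its worker and master
# paths. Per the mock assertions: every host writes its env vars to a file
# and starts sshd; a worker only bootstraps and waits, while the master
# bootstraps, builds the mpirun command, executes it, and then writes a
# status file out to the workers so they can exit. A control-flow sketch of
# that contract (illustrative `driver` object, not the driver source):
def _mpi_main_sketch(current_host, master_addr, driver):
    driver.write_env_vars_to_file()
    driver.start_sshd_daemon()
    if current_host != master_addr:
        driver.bootstrap_worker_node()  # block until the master finishes
        return
    driver.bootstrap_master_node()  # wait until all workers are reachable
    exit_code, _ = driver.execute_commands(driver.get_mpirun_command())
    driver.write_status_file_to_workers()  # release the waiting workers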
+"""MPI Driver Unit Tests.""" +from __future__ import absolute_import + +import os +import sys +import json + +from unittest.mock import patch, MagicMock + +sys.modules["utils"] = MagicMock() +sys.modules["mpi_utils"] = MagicMock() + +from sagemaker.modules.train.container_drivers.distributed_drivers import mpi_driver # noqa: E402 + + +DUMMY_MPI_COMMAND = [ + "mpirun", + "--host", + "algo-1,algo-2", + "-np", + "2", + "--verbose", + "-x", + "ENV_VAR1", + "python", + "-m", + "mpi4py", + "-m", + "script.py", +] + +DUMMY_DISTRIBUTED = { + "process_count_per_node": 2, + "mpi_additional_options": [ + "--verbose", + "-x", + "ENV_VAR1", + ], +} + + +@patch.dict( + os.environ, + { + "SM_CURRENT_HOST": "algo-2", + "SM_HOSTS": '["algo-1", "algo-2"]', + "SM_MASTER_ADDR": "algo-1", + "SM_HOST_COUNT": "2", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "/opt/ml/input/data/code/script.py", + }, +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") +def test_mpi_driver_worker( + mock_execute_commands, + mock_get_mpirun_command, + mock_hyperparameters_to_cli_args, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, + mock_start_sshd_daemon, + mock_write_env_vars_to_file, +): + mock_hyperparameters_to_cli_args.return_value = [] + + mpi_driver.main() + + mock_write_env_vars_to_file.assert_called_once() + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_worker_node.assert_called_once() + + mock_bootstrap_master_node.assert_not_called() + mock_get_mpirun_command.assert_not_called() + mock_execute_commands.assert_not_called() + + +@patch.dict( + os.environ, + { + "SM_CURRENT_HOST": "algo-1", + "SM_HOSTS": '["algo-1", "algo-2"]', + "SM_MASTER_ADDR": "algo-1", + "SM_HOST_COUNT": "2", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", + }, +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_env_vars_to_file" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.start_sshd_daemon") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_master_node" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.bootstrap_worker_node" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_process_count") +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.hyperparameters_to_cli_args" +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.get_mpirun_command" +) +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.execute_commands") +@patch( + 
"sagemaker.modules.train.container_drivers.distributed_drivers.mpi_driver.write_status_file_to_workers" +) +def test_mpi_driver_master( + mock_write_status_file_to_workers, + mock_execute_commands, + mock_get_mpirun_command, + mock_hyperparameters_to_cli_args, + mock_get_process_count, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, + mock_start_sshd_daemon, + mock_write_env_vars_to_file, +): + mock_hyperparameters_to_cli_args.return_value = [] + mock_get_mpirun_command.return_value = DUMMY_MPI_COMMAND + mock_get_process_count.return_value = 2 + mock_execute_commands.return_value = (0, "") + + mpi_driver.main() + + mock_write_env_vars_to_file.assert_called_once() + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_master_node.assert_called_once() + mock_get_mpirun_command.assert_called_once() + mock_execute_commands.assert_called_once_with(DUMMY_MPI_COMMAND) + mock_write_status_file_to_workers.assert_called_once() + + mock_bootstrap_worker_node.assert_not_called() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py new file mode 100644 index 0000000000..35208d708a --- /dev/null +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py @@ -0,0 +1,113 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""MPI Utils Unit Tests.""" +from __future__ import absolute_import + +import subprocess +from unittest.mock import Mock, patch + +import paramiko +import pytest + +# Mock the utils module before importing mpi_utils +mock_utils = Mock() +mock_utils.logger = Mock() +mock_utils.SM_EFA_NCCL_INSTANCES = [] +mock_utils.SM_EFA_RDMA_INSTANCES = [] +mock_utils.get_python_executable = Mock(return_value="/usr/bin/python") + +with patch.dict("sys.modules", {"utils": mock_utils}): + from sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils import ( + CustomHostKeyPolicy, + _can_connect, + write_status_file_to_workers, + ) + +TEST_HOST = "algo-1" +TEST_WORKER = "algo-2" +TEST_STATUS_FILE = "/tmp/test-status" + + +def test_custom_host_key_policy_valid_hostname(): + """Test CustomHostKeyPolicy accepts algo- prefixed hostnames.""" + policy = CustomHostKeyPolicy() + mock_client = Mock() + mock_key = Mock() + mock_key.get_name.return_value = "ssh-rsa" + + policy.missing_host_key(mock_client, "algo-1", mock_key) + + mock_client.get_host_keys.assert_called_once() + mock_client.get_host_keys().add.assert_called_once_with("algo-1", "ssh-rsa", mock_key) + + +def test_custom_host_key_policy_invalid_hostname(): + """Test CustomHostKeyPolicy rejects non-algo prefixed hostnames.""" + policy = CustomHostKeyPolicy() + mock_client = Mock() + mock_key = Mock() + + with pytest.raises(paramiko.SSHException) as exc_info: + policy.missing_host_key(mock_client, "invalid-1", mock_key) + + assert "Unknown host key for invalid-1" in str(exc_info.value) + mock_client.get_host_keys.assert_not_called() + + +@patch("paramiko.SSHClient") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") +def test_can_connect_success(mock_logger, mock_ssh_client): + """Test successful SSH connection.""" + mock_client = Mock() + mock_ssh_client.return_value.__enter__.return_value = mock_client + mock_client.connect.return_value = None # Successful connection + + result = _can_connect(TEST_HOST) + + assert result is True + mock_client.load_system_host_keys.assert_called_once() + mock_client.set_missing_host_key_policy.assert_called_once() + mock_client.connect.assert_called_once_with(TEST_HOST, port=22) + + +@patch("paramiko.SSHClient") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") +def test_can_connect_failure(mock_logger, mock_ssh_client): + """Test SSH connection failure.""" + mock_client = Mock() + mock_ssh_client.return_value.__enter__.return_value = mock_client + mock_client.connect.side_effect = paramiko.SSHException("Connection failed") + + result = _can_connect(TEST_HOST) + + assert result is False + mock_client.load_system_host_keys.assert_called_once() + mock_client.set_missing_host_key_policy.assert_called_once() + mock_client.connect.assert_called_once_with(TEST_HOST, port=22) + + +@patch("subprocess.run") +@patch("sagemaker.modules.train.container_drivers.distributed_drivers.mpi_utils.logger") +def test_write_status_file_to_workers_failure(mock_logger, mock_run): + """Test failed status file writing to workers with retry timeout.""" + mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") + + with pytest.raises(TimeoutError) as exc_info: + write_status_file_to_workers([TEST_WORKER], TEST_STATUS_FILE) + + assert f"Timed out waiting for {TEST_WORKER}" in str(exc_info.value) + assert mock_run.call_count > 1 # Verifies that retries occurred + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git 
a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py new file mode 100644 index 0000000000..2568346158 --- /dev/null +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py @@ -0,0 +1,150 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Torchrun Driver Unit Tests.""" +from __future__ import absolute_import + +import os +import sys +import json + +from unittest.mock import patch, MagicMock + +sys.modules["utils"] = MagicMock() + +from sagemaker.modules.train.container_drivers.distributed_drivers import ( # noqa: E402 + torchrun_driver, +) + +DUMMY_DISTRIBUTED = {"process_count_per_node": 2} + + +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_python_executable", + return_value="python3", +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), +) +def test_get_base_pytorch_command_torchrun(mock_pytorch_version, mock_get_python_executable): + assert torchrun_driver.get_base_pytorch_command() == ["torchrun"] + + +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_python_executable", + return_value="python3", +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(1, 8), +) +def test_get_base_pytorch_command_torch_distributed_launch( + mock_pytorch_version, mock_get_python_executable +): + assert torchrun_driver.get_base_pytorch_command() == ( + ["python3", "-m", "torch.distributed.launch"] + ) + + +@patch.dict( + os.environ, + { + "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge", + "SM_NETWORK_INTERFACE_NAME": "eth0", + "SM_HOST_COUNT": "1", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", + }, +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_process_count", + return_value=2, +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_base_pytorch_command", + return_value=["torchrun"], +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.hyperparameters_to_cli_args", + return_value=[], +) +def test_create_commands_single_node( + mock_hyperparameters_to_cli_args, + mock_get_base_pytorch_command, + mock_pytorch_version, + mock_get_process_count, +): + expected_command = [ + "torchrun", + "--nnodes=1", + "--nproc_per_node=2", + "script.py", + ] + + command = torchrun_driver.create_commands() + assert command == expected_command + + +@patch.dict( + os.environ, + { + "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge", + "SM_NETWORK_INTERFACE_NAME": "eth0", + "SM_HOST_COUNT": "2", + "SM_MASTER_ADDR": 
"algo-1", + "SM_MASTER_PORT": "7777", + "SM_CURRENT_HOST_RANK": "0", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", + }, +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_process_count", + return_value=2, +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.get_base_pytorch_command", + return_value=["torchrun"], +) +@patch( + "sagemaker.modules.train.container_drivers.distributed_drivers.torchrun_driver.hyperparameters_to_cli_args", + return_value=[], +) +def test_create_commands_multi_node( + mock_hyperparameters_to_cli_args, + mock_get_base_pytorch_command, + mock_pytorch_version, + mock_get_process_count, +): + expected_command = [ + "torchrun", + "--nnodes=2", + "--nproc_per_node=2", + "--master_addr=algo-1", + "--master_port=7777", + "--node_rank=0", + "script.py", + ] + + command = torchrun_driver.create_commands() + assert command == expected_command diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py new file mode 100644 index 0000000000..c563e0607f --- /dev/null +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py @@ -0,0 +1,144 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Container Utils Unit Tests.""" +from __future__ import absolute_import +import os + +from sagemaker.modules.train.container_drivers.common.utils import ( + safe_deserialize, + safe_serialize, + hyperparameters_to_cli_args, + get_process_count, +) + +SM_HPS = { + "boolean": "true", + "dict": '{"string":"value","integer":3,"list":[1,2,3],"dict":{"key":"value"},"boolean":true}', + "float": "3.14", + "integer": "1", + "list": "[1,2,3]", + "string": "Hello World", +} + + +def test_hyperparameters_to_cli_args(): + args = hyperparameters_to_cli_args(SM_HPS) + + assert args == [ + "--boolean", + "true", + "--dict", + '{"string": "value", "integer": 3, "list": [1, 2, 3], "dict": {"key": "value"}, "boolean": true}', + "--float", + "3.14", + "--integer", + "1", + "--list", + "[1, 2, 3]", + "--string", + "Hello World", + ] + + +def test_safe_deserialize_not_a_string(): + assert safe_deserialize(123) == 123 + assert safe_deserialize([1, 2, 3]) == [1, 2, 3] + assert safe_deserialize({"key": "value"}) == {"key": "value"} + + +def test_safe_deserialize_boolean_strings(): + assert safe_deserialize("true") is True + assert safe_deserialize("false") is False + + # The below are not valid JSON booleans + assert safe_deserialize("True") == "True" + assert safe_deserialize("False") == "False" + assert safe_deserialize("TRUE") == "TRUE" + assert safe_deserialize("FALSE") == "FALSE" + assert safe_deserialize("tRuE") == "tRuE" + assert safe_deserialize("fAlSe") == "fAlSe" + + +def test_safe_deserialize_valid_json_string(): + json_data = '{"key": "value", "number": 123, "boolean": true}' + expected_output = {"key": "value", "number": 123, "boolean": True} + assert safe_deserialize(json_data) == expected_output + + assert safe_deserialize("Hello World") == "Hello World" + assert safe_deserialize("12345") == 12345 + + assert safe_deserialize("3.14") == 3.14 + assert safe_deserialize("[1,2,3]") == [1, 2, 3] + + +def test_safe_deserialize_invalid_json_string(): + invalid_json = '{"key": value}' # Missing quotes around value so not valid json + assert safe_deserialize(invalid_json) == invalid_json + + +def test_safe_deserialize_null_string(): + assert safe_deserialize("null") == None # noqa: E711 + assert safe_deserialize("None") == "None" + + +def test_safe_serialize_string(): + assert safe_serialize("Hello World") == "Hello World" + assert safe_serialize("12345") == "12345" + assert safe_serialize("true") == "true" + + +def test_safe_serialize_serializable_data(): + assert safe_serialize({"key": "value", "number": 123, "boolean": True}) == ( + '{"key": "value", "number": 123, "boolean": true}' + ) + assert safe_serialize([1, 2, 3]) == "[1, 2, 3]" + assert safe_serialize(123) == "123" + assert safe_serialize(3.14) == "3.14" + assert safe_serialize(True) == "true" + assert safe_serialize(False) == "false" + assert safe_serialize(None) == "null" + + +def test_safe_serialize_custom_object(): + class CustomObject: + def __str__(self): + return "CustomObject" + + obj = CustomObject() + assert safe_serialize(obj) == "CustomObject" + + +def test_safe_serialize_invalid_data(): + invalid_data = {"key": set([1, 2, 3])} # Sets are not JSON serializable + assert safe_serialize(invalid_data) == str(invalid_data) + + +def test_safe_serialize_empty_data(): + assert safe_serialize("") == "" + assert safe_serialize([]) == "[]" + assert safe_serialize({}) == "{}" + + +def test_get_process_count(): + assert get_process_count() == 1 + assert get_process_count(2) == 2 + os.environ["SM_NUM_GPUS"] = "4" + assert get_process_count() == 4 
+ os.environ["SM_NUM_GPUS"] = "0" + os.environ["SM_NUM_NEURONS"] = "8" + assert get_process_count() == 8 + os.environ["SM_NUM_NEURONS"] = "0" + assert get_process_count() == 1 + del os.environ["SM_NUM_GPUS"] + del os.environ["SM_NUM_NEURONS"] + assert get_process_count() == 1 diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/__init__.py b/tests/unit/sagemaker/modules/train/sm_recipes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py new file mode 100644 index 0000000000..17cfda55b0 --- /dev/null +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -0,0 +1,448 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Utility functions for SageMaker training recipes Tests.""" +from __future__ import absolute_import + +import pytest +from unittest.mock import patch, MagicMock + +import yaml +from omegaconf import OmegaConf +from urllib.request import urlretrieve +from tempfile import NamedTemporaryFile + +from sagemaker.modules.train.sm_recipes.utils import ( + _load_base_recipe, + _get_args_from_recipe, + _load_recipes_cfg, + _configure_gpu_args, + _configure_trainium_args, + _get_trainining_recipe_gpu_model_name_and_script, + _is_nova_recipe, + _get_args_from_nova_recipe, +) +from sagemaker.modules.utils import _run_clone_command_silent +from sagemaker.modules.configs import Compute + + +@pytest.fixture(scope="module") +def training_recipes_cfg(): + return _load_recipes_cfg() + + +@pytest.fixture(scope="module") +def temporary_recipe(): + data = { + "trainer": {"num_nodes": 2, "max_epochs": 10}, + "model": {"model_type": "llama_v3", "num_classes": 10, "num_layers": 10}, + } + with NamedTemporaryFile(suffix=".yaml", delete=False) as f: + with open(f.name, "w") as file: + yaml.dump(data, file) + yield f.name + + +def test_load_base_recipe_with_overrides(temporary_recipe, training_recipes_cfg): + expected_epochs = 20 + expected_layers = 15 + + recipe_overrides = { + "trainer": {"max_epochs": expected_epochs}, + "model": {"num_layers": expected_layers}, + } + + load_recipe = _load_base_recipe( + training_recipe=temporary_recipe, + recipe_overrides=recipe_overrides, + training_recipes_cfg=training_recipes_cfg, + ) + + assert ( + load_recipe["trainer"]["max_epochs"] == expected_epochs + and load_recipe["model"]["num_layers"] == expected_layers + ) + + +@pytest.mark.parametrize( + "test_case", + [ + {"recipe_type": "local"}, + {"recipe_type": "sagemaker"}, + {"recipe_type": "url"}, + {"recipe_type": "not_found"}, + ], +) +@patch("sagemaker.modules.train.sm_recipes.utils.urlretrieve") +@patch("sagemaker.modules.train.sm_recipes.utils._run_clone_command_silent") +def test_load_base_recipe_types( + mock_clone, mock_retrieve, temporary_recipe, training_recipes_cfg, test_case +): + recipe_type = test_case["recipe_type"] + + if recipe_type == "not_found": + with pytest.raises(ValueError): + _load_base_recipe( + 
training_recipe="not_found", + recipe_overrides=None, + training_recipes_cfg=training_recipes_cfg, + ) + + if recipe_type == "local": + load_recipe = _load_base_recipe( + training_recipe=temporary_recipe, + recipe_overrides=None, + training_recipes_cfg=training_recipes_cfg, + ) + assert load_recipe is not None + assert "trainer" in load_recipe + + if recipe_type == "sagemaker": + mock_clone.side_effect = _run_clone_command_silent + load_recipe = _load_base_recipe( + training_recipe="training/llama/p4_hf_llama3_70b_seq8k_gpu", + recipe_overrides=None, + training_recipes_cfg=training_recipes_cfg, + ) + assert load_recipe is not None + assert "trainer" in load_recipe + assert mock_clone.call_args.args[0] == training_recipes_cfg.get("launcher_repo") + + if recipe_type == "url": + url = "https://raw.githubusercontent.com/aws-neuron/neuronx-distributed-training/refs/heads/main/examples/conf/hf_llama3_8B_config.yaml" # noqa + mock_retrieve.side_effect = urlretrieve + load_recipe = _load_base_recipe( + training_recipe=url, + recipe_overrides=None, + training_recipes_cfg=training_recipes_cfg, + ) + assert load_recipe is not None + assert "trainer" in load_recipe + assert mock_retrieve.call_args.args[0] == url + + +@pytest.mark.parametrize( + "test_case", + [ + {"type": "gpu", "instance_type": "ml.p4d.24xlarge"}, + {"type": "trn", "instance_type": "ml.trn1.32xlarge"}, + {"type": "cpu", "instance_type": "ml.c5.4xlarge"}, + ], +) +@patch("sagemaker.modules.train.sm_recipes.utils._configure_gpu_args") +@patch("sagemaker.modules.train.sm_recipes.utils._configure_trainium_args") +def test_get_args_from_recipe_compute( + mock_trainium_args, mock_gpu_args, temporary_recipe, test_case +): + compute = Compute(instance_type=test_case["instance_type"]) + if test_case["type"] == "gpu": + mock_gpu_args.side_effect = _configure_gpu_args + + args = _get_args_from_recipe( + training_recipe=temporary_recipe, + compute=compute, + region_name="us-west-2", + recipe_overrides=None, + requirements=None, + ) + assert mock_gpu_args.call_count == 1 + assert mock_trainium_args.call_count == 0 + + if test_case["type"] == "trn": + mock_trainium_args.side_effect = _configure_trainium_args + + args = _get_args_from_recipe( + training_recipe=temporary_recipe, + compute=compute, + region_name="us-west-2", + recipe_overrides=None, + requirements=None, + ) + assert mock_gpu_args.call_count == 0 + assert mock_trainium_args.call_count == 1 + + if test_case["type"] == "cpu": + with pytest.raises(ValueError): + args = _get_args_from_recipe( + training_recipe=temporary_recipe, + compute=compute, + region_name="us-west-2", + recipe_overrides=None, + requirements=None, + ) + assert mock_gpu_args.call_count == 0 + assert mock_trainium_args.call_count == 0 + assert args is None + + +@patch("sagemaker.modules.train.sm_recipes.utils._get_args_from_nova_recipe") +def test_get_args_from_recipe_with_nova_and_role(mock_get_args_from_nova_recipe, temporary_recipe): + # Set up mock return value + mock_args = {"hyperparameters": {}} + mock_dir = MagicMock() + mock_get_args_from_nova_recipe.return_value = (mock_args, mock_dir) + + # Create a Nova recipe with distillation data + recipe = OmegaConf.create( + {"training_config": {"distillation_data": True, "kms_key": "alias/my-kms-key"}} + ) + compute = Compute(instance_type="ml.g5.xlarge") + role = "arn:aws:iam::123456789012:role/SageMakerRole" + + # Mock the Nova recipe detection to return True + with patch("sagemaker.modules.train.sm_recipes.utils._is_nova_recipe", return_value=True): + 
_get_args_from_recipe( + training_recipe=recipe, + compute=compute, + region_name="us-west-2", + recipe_overrides=None, + requirements=None, + role=role, + ) + + # Verify _get_args_from_nova_recipe was called with the role parameter + mock_get_args_from_nova_recipe.assert_called_once_with(recipe, compute, role=role) + + +@pytest.mark.parametrize( + "test_case", + [ + {"model_type": "llama_v4", "script": "llama_pretrain.py", "model_base_name": "llama"}, + { + "model_type": "llama_v3", + "script": "llama_pretrain.py", + "model_base_name": "llama", + }, + { + "model_type": "mistral", + "script": "mistral_pretrain.py", + "model_base_name": "mistral", + }, + { + "model_type": "deepseek_llamav3", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + { + "model_type": "deepseek_qwenv2", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + { + "model_type": "gpt_oss", + "script": "custom_pretrain.py", + "model_base_name": "custom_model", + }, + ], +) +def test_get_trainining_recipe_gpu_model_name_and_script(test_case): + model_type = test_case["model_type"] + script = test_case["script"] + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) + assert model_base_name == test_case["model_base_name"] + assert script == test_case["script"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "run": { + "model_type": "amazon.nova", + "model_name_or_path": "some-model", + } + }, + "is_nova": True, + }, + { + "recipe": { + "run": { + "model_type": "amazon.nova.other", + "model_name_or_path": "some-model", + } + }, + "is_nova": True, + }, + {"recipe": {"run": {"model_type": "amazon.nova.other"}}, "is_nova": False}, + { + "recipe": {"run": {"model_type": "other.model", "model_name_or_path": "some-model"}}, + "is_nova": False, + }, + { + "recipe": {"training_config": {"distillation_data": "s3://bucket/distillation-data"}}, + "is_nova": True, + }, + { + "recipe": {"training_config": {"some_other_field": "value"}}, + "is_nova": False, + }, + ], + ids=[ + "nova_model", + "nova_model_subtype", + "nova_missing_model_path", + "non_nova_model", + "distillation_data", + "no_distillation_data", + ], +) +def test_is_nova_recipe(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + is_nova = _is_nova_recipe(recipe) + assert is_nova == test_case["is_nova"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "run": {"model_type": "amazon.nova", "model_name_or_path": "dummy-test"}, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "hyperparameters": {"base_model": "dummy-test"}, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + { + "recipe": { + "run": { + "model_type": "amazon.nova", + "model_name_or_path": "s3://bucket/path/to/model", + }, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "hyperparameters": {"base_model_location": "s3://bucket/path/to/model"}, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + { + "recipe": { + "run": { + "model_type": "amazon.nova", + "model_name_or_path": "s3://bucket/path/to/model", + "replicas": 4, + }, + }, + "compute": Compute(instance_type="ml.m5.xlarge"), + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=4), + 
"hyperparameters": {"base_model_location": "s3://bucket/path/to/model"}, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + { + "recipe": { + "run": { + "model_type": "amazon.nova", + "model_name_or_path": "s3://bucket/path/to/model", + "replicas": 2, + }, + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=4), + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=4), + "hyperparameters": {"base_model_location": "s3://bucket/path/to/model"}, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + ], +) +def test_get_args_from_nova_recipe(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + args, _ = _get_args_from_nova_recipe(recipe=recipe, compute=test_case["compute"]) + assert args == test_case["expected_args"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "training_config": { + "distillation_data": "s3://bucket/distillation-data", + "kms_key": "alias/my-kms-key", + } + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "role": "arn:aws:iam::123456789012:role/SageMakerRole", + "expected_args": { + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "hyperparameters": { + "distillation_data": "s3://bucket/distillation-data", + "role_arn": "arn:aws:iam::123456789012:role/SageMakerRole", + "kms_key": "alias/my-kms-key", + }, + "training_image": None, + "source_code": None, + "distributed": None, + }, + }, + ], +) +def test_get_args_from_nova_recipe_with_distillation(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + args, _ = _get_args_from_nova_recipe( + recipe=recipe, compute=test_case["compute"], role=test_case["role"] + ) + assert args == test_case["expected_args"] + + +@pytest.mark.parametrize( + "test_case", + [ + { + "recipe": { + "training_config": { + "distillation_data": "s3://bucket/distillation-data", + # Missing kms_key + } + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + "role": "arn:aws:iam::123456789012:role/SageMakerRole", + }, + { + "recipe": { + "training_config": { + "distillation_data": "s3://bucket/distillation-data", + "kms_key": "alias/my-kms-key", + } + }, + "compute": Compute(instance_type="ml.m5.xlarge", instance_count=2), + # Missing role + "role": None, + }, + ], + ids=[ + "missing_kms_key", + "missing_role", + ], +) +def test_get_args_from_nova_recipe_with_distillation_errors(test_case): + recipe = OmegaConf.create(test_case["recipe"]) + with pytest.raises(ValueError): + _get_args_from_nova_recipe( + recipe=recipe, compute=test_case["compute"], role=test_case.get("role") + ) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py new file mode 100644 index 0000000000..73893ea7f4 --- /dev/null +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -0,0 +1,1478 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""ModelTrainer Tests.""" +from __future__ import absolute_import + +import shutil +import tempfile +import json +import os +import yaml +import pytest +from pydantic import ValidationError +from unittest.mock import patch, MagicMock, ANY, mock_open +from tempfile import NamedTemporaryFile + +from sagemaker import image_uris +from sagemaker_core.main.resources import TrainingJob +from sagemaker_core.main.shapes import ( + ResourceConfig, + VpcConfig, + AlgorithmSpecification, +) + +from sagemaker.config import SAGEMAKER, PYTHON_SDK, MODULES +from sagemaker.config.config_schema import ( + MODEL_TRAINER, + _simple_path, + TRAINING_JOB_RESOURCE_CONFIG_PATH, +) +from sagemaker.modules import Session +from sagemaker.modules.train.model_trainer import ModelTrainer, Mode +from sagemaker.modules.constants import ( + DEFAULT_INSTANCE_TYPE, + DISTRIBUTED_JSON, + SOURCE_CODE_JSON, + TRAIN_SCRIPT, + SM_RECIPE_CONTAINER_PATH, +) +from sagemaker.modules.configs import ( + Compute, + StoppingCondition, + RetryStrategy, + OutputDataConfig, + SourceCode, + RemoteDebugConfig, + TensorBoardOutputConfig, + InfraCheckConfig, + SessionChainingConfig, + InputData, + Networking, + TrainingImageConfig, + TrainingRepositoryAuthConfig, + CheckpointConfig, + Tag, + S3DataSource, + FileSystemDataSource, + Channel, + DataSource, + MetricDefinition, +) +from sagemaker.modules.distributed import Torchrun, SMP, MPI +from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg +from sagemaker.modules.templates import EXEUCTE_DISTRIBUTED_DRIVER +from tests.unit import DATA_DIR + +DEFAULT_BASE_NAME = "dummy-image-job" +DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest" +DEFAULT_BUCKET = "sagemaker-us-west-2-000000000000" +DEFAULT_ROLE = "arn:aws:iam::000000000000:role/test-role" +DEFAULT_BUCKET_PREFIX = "sample-prefix" +DEFAULT_REGION = "us-west-2" +DEFAULT_SOURCE_DIR = f"{DATA_DIR}/modules/script_mode" +DEFAULT_COMPUTE_CONFIG = Compute(instance_type=DEFAULT_INSTANCE_TYPE, instance_count=1) +DEFAULT_OUTPUT_DATA_CONFIG = OutputDataConfig( + s3_output_path=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{DEFAULT_BASE_NAME}", + compression_type="GZIP", + kms_key_id=None, +) +DEFAULT_STOPPING_CONDITION = StoppingCondition( + max_runtime_in_seconds=3600, + max_pending_time_in_seconds=None, + max_wait_time_in_seconds=None, +) +DEFAULT_SOURCE_CODE = SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", +) +DEFAULT_ENTRYPOINT = ["/bin/bash"] +DEFAULT_ARGUMENTS = [ + "-c", + ( + "chmod +x /opt/ml/input/data/sm_drivers/sm_train.sh " + + "&& /opt/ml/input/data/sm_drivers/sm_train.sh" + ), +] + + +@pytest.fixture(scope="module", autouse=True) +def modules_session(): + with patch("sagemaker.modules.Session", spec=Session) as session_mock: + session_instance = session_mock.return_value + session_instance.default_bucket.return_value = DEFAULT_BUCKET + session_instance.get_caller_identity_arn.return_value = DEFAULT_ROLE + session_instance.default_bucket_prefix = DEFAULT_BUCKET_PREFIX + session_instance.boto_session = MagicMock(spec="boto3.session.Session") + session_instance.boto_region_name = DEFAULT_REGION + yield session_instance + + +@pytest.fixture +def model_trainer(): + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + compute=DEFAULT_COMPUTE_CONFIG, + stopping_condition=DEFAULT_STOPPING_CONDITION, + output_data_config=DEFAULT_OUTPUT_DATA_CONFIG, + ) + return trainer + + +@pytest.mark.parametrize( + "test_case", + [ + { + "init_params": 
{}, + "should_throw": True, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + }, + "should_throw": False, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "algorithm_name": "dummy-arn", + }, + "should_throw": True, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": SourceCode( + entry_script="train.py", + ), + }, + "should_throw": True, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": SourceCode( + source_dir="s3://bucket/requirements.txt", + entry_script="custom_script.py", + ), + }, + "should_throw": True, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": DEFAULT_SOURCE_CODE, + }, + "should_throw": False, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": SourceCode( + source_dir=f"{DEFAULT_SOURCE_DIR}/code.tar.gz", + entry_script="custom_script.py", + ), + }, + "should_throw": False, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": SourceCode( + source_dir="s3://bucket/code/", + entry_script="custom_script.py", + ), + }, + "should_throw": False, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": SourceCode( + source_dir="s3://bucket/code/code.tar.gz", + entry_script="custom_script.py", + ), + }, + "should_throw": False, + }, + { + "init_params": { + "training_image": DEFAULT_IMAGE, + "source_code": SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + command="python custom_script.py", + ignore_patterns=["data"], + ), + }, + "should_throw": False, + }, + ], + ids=[ + "no_params", + "training_image_and_algorithm_name", + "only_training_image", + "unsupported_source_code_missing_source_dir", + "unsupported_source_code_s3_other_file", + "supported_source_code_local_dir", + "supported_source_code_local_tar_file", + "supported_source_code_s3_dir", + "supported_source_code_s3_tar_file", + "supported_source_code_ignore_patterns", + ], +) +def test_model_trainer_param_validation(test_case, modules_session): + if test_case["should_throw"]: + with pytest.raises(ValueError): + ModelTrainer(**test_case["init_params"], sagemaker_session=modules_session) + else: + trainer = ModelTrainer(**test_case["init_params"], sagemaker_session=modules_session) + assert trainer is not None + assert trainer.training_image == DEFAULT_IMAGE + assert trainer.compute == DEFAULT_COMPUTE_CONFIG + assert trainer.output_data_config == DEFAULT_OUTPUT_DATA_CONFIG + assert trainer.stopping_condition == DEFAULT_STOPPING_CONDITION + assert trainer.base_job_name == DEFAULT_BASE_NAME + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_train_with_default_params(mock_training_job, model_trainer): + model_trainer.train() + + mock_training_job.create.assert_called_once() + + training_job_instance = mock_training_job.create.return_value + training_job_instance.wait.assert_called_once_with(logs=True) + + +@pytest.mark.parametrize( + "default_config", + [ + { + "path_name": "sourceCode", + "config_value": {"command": "echo 'Hello World' && env"}, + }, + { + "path_name": "compute", + "config_value": {"volume_size_in_gb": 45}, + }, + { + "path_name": "networking", + "config_value": { + "enable_network_isolation": True, + "security_group_ids": ["sg-1"], + "subnets": ["subnet-1"], + }, + }, + { + "path_name": "stoppingCondition", + "config_value": {"max_runtime_in_seconds": 15}, + }, + { + "path_name": "trainingImageConfig", + "config_value": {"training_repository_access_mode": "private"}, + }, + { + "path_name": 
"outputDataConfig", + "config_value": {"s3_output_path": "Sample S3 path"}, + }, + { + "path_name": "checkpointConfig", + "config_value": {"s3_uri": "sample uri"}, + }, + ], +) +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +@patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") +@patch("sagemaker.modules.train.model_trainer.ModelTrainer.create_input_data_channel") +def test_train_with_intelligent_defaults( + mock_create_input_data_channel, + mock_resolve_value_from_config, + mock_training_job, + default_config, + model_trainer, +): + mock_resolve_value_from_config.side_effect = lambda **kwargs: ( + default_config["config_value"] + if kwargs["config_path"] + == _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, MODEL_TRAINER, default_config["path_name"]) + else None + ) + + model_trainer.train() + + mock_training_job.create.assert_called_once() + + training_job_instance = mock_training_job.create.return_value + training_job_instance.wait.assert_called_once_with(logs=True) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +@patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") +def test_train_with_intelligent_defaults_training_job_space( + mock_resolve_value_from_config, mock_training_job, model_trainer +): + mock_resolve_value_from_config.side_effect = lambda **kwargs: ( + { + "instanceType": DEFAULT_INSTANCE_TYPE, + "instanceCount": 1, + "volumeSizeInGB": 30, + } + if kwargs["config_path"] == TRAINING_JOB_RESOURCE_CONFIG_PATH + else None + ) + + model_trainer.train() + + mock_training_job.create.assert_called_once_with( + training_job_name=ANY, + algorithm_specification=ANY, + hyper_parameters={}, + input_data_config=[], + resource_config=ResourceConfig( + volume_size_in_gb=30, instance_type="ml.m5.xlarge", instance_count=1 + ), + vpc_config=None, + session=ANY, + role_arn="arn:aws:iam::000000000000:" "role/test-role", + tags=None, + stopping_condition=StoppingCondition( + max_runtime_in_seconds=3600, + max_wait_time_in_seconds=None, + max_pending_time_in_seconds=None, + ), + output_data_config=OutputDataConfig( + s3_output_path="s3://" + "sagemaker-us-west-2" + "-000000000000/" + "sample-prefix/" + "dummy-image-job", + kms_key_id=None, + compression_type="GZIP", + ), + checkpoint_config=None, + environment=None, + enable_managed_spot_training=None, + enable_inter_container_traffic_encryption=None, + enable_network_isolation=None, + remote_debug_config=None, + tensor_board_output_config=None, + retry_strategy=None, + infra_check_config=None, + session_chaining_config=None, + ) + + training_job_instance = mock_training_job.create.return_value + training_job_instance.wait.assert_called_once_with(logs=True) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +@patch.object(ModelTrainer, "_get_input_data_config") +def test_train_with_input_data_channels(mock_get_input_config, mock_training_job, model_trainer): + train_data = InputData(channel_name="train", data_source="train/dir") + test_data = InputData(channel_name="test", data_source="test/dir") + mock_input_data_config = [train_data, test_data] + + model_trainer.train(input_data_config=mock_input_data_config) + + mock_get_input_config.assert_called_once_with(mock_input_data_config, ANY) + mock_training_job.create.assert_called_once() + + +@pytest.mark.parametrize( + "test_case", + [ + { + "channel_name": "test", + "data_source": DATA_DIR, + "valid": True, + }, + { + "channel_name": "test", + "data_source": f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test", 
+ "valid": True, + }, + { + "channel_name": "test", + "data_source": S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test", + s3_data_distribution_type="FullyReplicated", + ), + "valid": True, + }, + { + "channel_name": "test", + "data_source": FileSystemDataSource( + file_system_id="fs-000000000000", + file_system_access_mode="ro", + file_system_type="EFS", + directory_path="/data/test", + ), + "valid": True, + }, + { + "channel_name": "test", + "data_source": "fake/path", + "valid": False, + }, + ], + ids=[ + "valid_local_path", + "valid_s3_path", + "valid_s3_data_source", + "valid_file_system_data_source", + "invalid_path", + ], +) +@patch("sagemaker.modules.train.model_trainer.Session.upload_data") +@patch("sagemaker.modules.train.model_trainer.Session.default_bucket") +def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_trainer, test_case): + expected_s3_uri = f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test" + mock_upload_data.return_value = expected_s3_uri + mock_default_bucket.return_value = DEFAULT_BUCKET + if not test_case["valid"]: + with pytest.raises(ValueError): + model_trainer.create_input_data_channel( + test_case["channel_name"], test_case["data_source"] + ) + else: + channel = model_trainer.create_input_data_channel( + test_case["channel_name"], test_case["data_source"] + ) + assert channel.channel_name == test_case["channel_name"] + if isinstance(test_case["data_source"], S3DataSource): + assert channel.data_source.s3_data_source == test_case["data_source"] + elif isinstance(test_case["data_source"], FileSystemDataSource): + assert channel.data_source.file_system_data_source == test_case["data_source"] + else: + assert channel.data_source.s3_data_source.s3_uri == expected_s3_uri + + +@pytest.mark.parametrize( + "test_case", + [ + { + "source_code": DEFAULT_SOURCE_CODE, + "distributed": Torchrun(), + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="Torchrun", driver_script="torchrun_driver.py" + ), + "expected_hyperparameters": {}, + }, + { + "source_code": DEFAULT_SOURCE_CODE, + "distributed": Torchrun( + smp=SMP( + hybrid_shard_degree=3, + sm_activation_offloading=True, + allow_empty_shards=True, + tensor_parallel_degree=5, + ) + ), + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="Torchrun", driver_script="torchrun_driver.py" + ), + "expected_hyperparameters": { + "mp_parameters": json.dumps( + { + "hybrid_shard_degree": 3, + "sm_activation_offloading": True, + "allow_empty_shards": True, + "tensor_parallel_degree": 5, + } + ), + }, + }, + { + "source_code": DEFAULT_SOURCE_CODE, + "distributed": MPI( + mpi_additional_options=["-x", "VAR1", "-x", "VAR2"], + ), + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="MPI", driver_script="mpi_driver.py" + ), + "expected_hyperparameters": {}, + }, + ], + ids=[ + "torchrun", + "torchrun_smp", + "mpi", + ], +) +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory") +@patch("sagemaker.modules.train.model_trainer.resolve_value_from_config") +def test_train_with_distributed_config( + mock_resolve_value_from_config, + mock_tmp_dir, + mock_training_job, + test_case, + request, + modules_session, +): + mock_resolve_value_from_config.return_value = None + modules_session.upload_data.return_value = ( + f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test" + ) + + tmp_dir = tempfile.TemporaryDirectory() 
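+    # Replace TemporaryDirectory's cleanup with a no-op so the driver files written by
+    # train() survive long enough to be inspected; the finally block below removes the
+    # directory explicitly.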
+    tmp_dir._cleanup = False
+    tmp_dir.cleanup = lambda: None
+    mock_tmp_dir.return_value = tmp_dir
+
+    expected_train_script_path = os.path.join(tmp_dir.name, TRAIN_SCRIPT)
+    expected_runner_json_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON)
+    expected_source_code_json_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON)
+
+    try:
+        model_trainer = ModelTrainer(
+            sagemaker_session=modules_session,
+            training_image=DEFAULT_IMAGE,
+            source_code=test_case["source_code"],
+            distributed=test_case["distributed"],
+        )
+
+        model_trainer.train()
+        mock_training_job.create.assert_called_once()
+        assert mock_training_job.create.call_args.kwargs["hyper_parameters"] == (
+            test_case["expected_hyperparameters"]
+        )
+
+        assert os.path.exists(expected_train_script_path)
+        with open(expected_train_script_path, "r") as f:
+            train_script_content = f.read()
+        assert test_case["expected_template"] in train_script_content
+
+        assert os.path.exists(expected_runner_json_path)
+        with open(expected_runner_json_path, "r") as f:
+            runner_json_content = f.read()
+        assert test_case["distributed"].model_dump() == (json.loads(runner_json_content))
+
+        assert os.path.exists(expected_source_code_json_path)
+        with open(expected_source_code_json_path, "r") as f:
+            source_code_json_content = f.read()
+        assert test_case["source_code"].model_dump() == (json.loads(source_code_json_content))
+    finally:
+        shutil.rmtree(tmp_dir.name)
+        assert not os.path.exists(tmp_dir.name)
+
+
+@patch("sagemaker.modules.train.model_trainer.TrainingJob")
+def test_train_stores_created_training_job(mock_training_job, model_trainer):
+    mock_training_job.create.return_value = TrainingJob(training_job_name="Created-job")
+    model_trainer.train(wait=False)
+    assert model_trainer._latest_training_job is not None
+    assert model_trainer._latest_training_job == TrainingJob(training_job_name="Created-job")
+
+
+@patch("sagemaker.modules.train.model_trainer.TrainingJob")
+def test_tensorboard_output_config(mock_training_job, modules_session):
+    image_uri = DEFAULT_IMAGE
+    role = DEFAULT_ROLE
+    tensorboard_output_config = TensorBoardOutputConfig(
+        s3_output_path=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}",
+        local_path="/opt/ml/output/tensorboard",
+    )
+
+    model_trainer = ModelTrainer(
+        training_image=image_uri,
+        sagemaker_session=modules_session,
+        role=role,
+    ).with_tensorboard_output_config(tensorboard_output_config)
+
+    assert model_trainer._tensorboard_output_config == tensorboard_output_config
+
+    with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data:
+        mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix"
+        model_trainer.train()
+
+        mock_training_job.create.assert_called_once()
+        assert (
+            mock_training_job.create.call_args.kwargs["tensor_board_output_config"]
+            == tensorboard_output_config
+        )
+
+
+@patch("sagemaker.modules.train.model_trainer.TrainingJob")
+def test_retry_strategy(mock_training_job, modules_session):
+    image_uri = DEFAULT_IMAGE
+    role = DEFAULT_ROLE
+    retry_strategy = RetryStrategy(
+        maximum_retry_attempts=3,
+    )
+
+    model_trainer = ModelTrainer(
+        training_image=image_uri,
+        sagemaker_session=modules_session,
+        role=role,
+    ).with_retry_strategy(retry_strategy)
+
+    assert model_trainer._retry_strategy == retry_strategy
+
+    with
patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["retry_strategy"] == retry_strategy + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_infra_check_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + infra_check_config = InfraCheckConfig( + enable_infra_check=True, + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_infra_check_config(infra_check_config) + + assert model_trainer._infra_check_config == infra_check_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["infra_check_config"] == infra_check_config + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_session_chaining_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + session_chaining_config = SessionChainingConfig( + enable_session_tag_chaining=True, + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_session_chaining_config(session_chaining_config) + + assert model_trainer._session_chaining_config == session_chaining_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["session_chaining_config"] + == session_chaining_config + ) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_remote_debug_config(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + remote_debug_config = RemoteDebugConfig( + enable_remote_debug=True, + ) + + model_trainer = ModelTrainer( + training_image=image_uri, + sagemaker_session=modules_session, + role=role, + ).with_remote_debug_config(remote_debug_config) + + assert model_trainer._remote_debug_config == remote_debug_config + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + mock_training_job.create.call_args.kwargs["remote_debug_config"] == remote_debug_config + ) + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_metric_definitions(mock_training_job, modules_session): + image_uri = DEFAULT_IMAGE + role = DEFAULT_ROLE + metric_definitions = [ + MetricDefinition( + name="loss", + regex="Loss: (.*?);", + ) + ] + + model_trainer = ModelTrainer( + training_image=image_uri, sagemaker_session=modules_session, role=role + ).with_metric_definitions(metric_definitions) + + with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data: + mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix" + model_trainer.train() + + mock_training_job.create.assert_called_once() + assert ( + 
mock_training_job.create.call_args.kwargs["algorithm_specification"].metric_definitions + == metric_definitions + ) + + +@patch("sagemaker.modules.train.model_trainer._get_unique_name") +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_model_trainer_full_init(mock_training_job, mock_unique_name, modules_session): + def mock_upload_data(path, bucket, key_prefix): + return f"s3://{bucket}/{key_prefix}" + + modules_session.upload_data.side_effect = mock_upload_data + + training_mode = Mode.SAGEMAKER_TRAINING_JOB + role = DEFAULT_ROLE + source_code = DEFAULT_SOURCE_CODE + distributed = Torchrun() + compute = Compute( + instance_type=DEFAULT_INSTANCE_TYPE, + instance_count=1, + volume_size_in_gb=30, + volume_kms_key_id="key-id", + keep_alive_period_in_seconds=3600, + enable_managed_spot_training=True, + ) + networking = Networking( + security_group_ids=["sg-000000000000"], + subnets=["subnet-000000000000"], + enable_network_isolation=True, + enable_inter_container_traffic_encryption=True, + ) + stopping_condition = DEFAULT_STOPPING_CONDITION + training_image = DEFAULT_IMAGE + training_image_config = TrainingImageConfig( + training_repository_access_mode="Platform", + training_repository_auth_config=TrainingRepositoryAuthConfig( + training_repository_credentials_provider_arn="arn:aws:lambda:us-west-2:000000000000:function:dummy-function" + ), + ) + output_data_config = DEFAULT_OUTPUT_DATA_CONFIG + + local_input_data = InputData( + channel_name="train", data_source=f"{DEFAULT_SOURCE_DIR}/data/train" + ) + s3_data_source_input = InputData( + channel_name="test", + data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}/data/test", + s3_data_distribution_type="FullyReplicated", + attribute_names=["label"], + instance_group_names=["instance-group"], + ), + ) + file_system_input = InputData( + channel_name="validation", + data_source=FileSystemDataSource( + file_system_id="fs-000000000000", + file_system_access_mode="ro", + file_system_type="EFS", + directory_path="/data/validation", + ), + ) + input_data_config = [local_input_data, s3_data_source_input, file_system_input] + checkpoint_config = CheckpointConfig( + local_path="/opt/ml/checkpoints", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}/checkpoints", + ) + training_input_mode = "File" + environment = {"ENV_VAR": "value"} + hyperparameters = {"key": "value"} + tags = [Tag(key="key", value="value")] + + model_trainer = ModelTrainer( + training_mode=training_mode, + sagemaker_session=modules_session, + role=role, + source_code=source_code, + distributed=distributed, + compute=compute, + networking=networking, + stopping_condition=stopping_condition, + training_image=training_image, + training_image_config=training_image_config, + output_data_config=output_data_config, + input_data_config=input_data_config, + checkpoint_config=checkpoint_config, + training_input_mode=training_input_mode, + environment=environment, + hyperparameters=hyperparameters, + tags=tags, + ) + + assert model_trainer.training_mode == training_mode + assert model_trainer.sagemaker_session == modules_session + assert model_trainer.role == role + assert model_trainer.source_code == source_code + assert model_trainer.distributed == distributed + assert model_trainer.compute == compute + assert model_trainer.networking == networking + assert model_trainer.stopping_condition == stopping_condition + assert model_trainer.training_image == training_image + assert model_trainer.training_image_config == 
training_image_config + assert model_trainer.output_data_config == output_data_config + assert model_trainer.input_data_config == input_data_config + assert model_trainer.checkpoint_config == checkpoint_config + assert model_trainer.training_input_mode == training_input_mode + assert model_trainer.environment == environment + assert model_trainer.hyperparameters == hyperparameters + assert model_trainer.tags == tags + + unique_name = "training-job" + mock_unique_name.return_value = unique_name + + model_trainer.train() + + mock_training_job.create.assert_called_once_with( + training_job_name=unique_name, + algorithm_specification=AlgorithmSpecification( + training_input_mode=training_input_mode, + training_image=training_image, + algorithm_name=None, + metric_definitions=None, + container_entrypoint=DEFAULT_ENTRYPOINT, + container_arguments=DEFAULT_ARGUMENTS, + training_image_config=training_image_config, + ), + hyper_parameters=hyperparameters, + input_data_config=[ + Channel( + channel_name=local_input_data.channel_name, + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{DEFAULT_BASE_NAME}/{unique_name}/input/train", # noqa: E501 + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ), + Channel( + channel_name=s3_data_source_input.channel_name, + data_source=DataSource(s3_data_source=s3_data_source_input.data_source), + ), + Channel( + channel_name=file_system_input.channel_name, + data_source=DataSource(file_system_data_source=file_system_input.data_source), + ), + Channel( + channel_name="code", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{DEFAULT_BASE_NAME}/{unique_name}/input/code", # noqa: E501 + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ), + Channel( + channel_name="sm_drivers", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{DEFAULT_BASE_NAME}/{unique_name}/input/sm_drivers", # noqa: E501 + s3_data_distribution_type="FullyReplicated", + ), + ), + input_mode="File", + ), + ], + resource_config=ResourceConfig( + instance_type=compute.instance_type, + instance_count=compute.instance_count, + volume_size_in_gb=compute.volume_size_in_gb, + volume_kms_key_id=compute.volume_kms_key_id, + keep_alive_period_in_seconds=compute.keep_alive_period_in_seconds, + ), + vpc_config=VpcConfig( + security_group_ids=networking.security_group_ids, + subnets=networking.subnets, + ), + session=ANY, + role_arn=role, + tags=tags, + stopping_condition=stopping_condition, + output_data_config=output_data_config, + checkpoint_config=checkpoint_config, + environment=environment, + enable_managed_spot_training=compute.enable_managed_spot_training, + enable_inter_container_traffic_encryption=( + networking.enable_inter_container_traffic_encryption + ), + enable_network_isolation=networking.enable_network_isolation, + remote_debug_config=None, + tensor_board_output_config=None, + retry_strategy=None, + infra_check_config=None, + session_chaining_config=None, + ) + + +def test_model_trainer_gpu_recipe_full_init(modules_session): + training_recipe = "training/llama/p4_hf_llama3_70b_seq8k_gpu" + recipe_overrides = {"run": {"results_dir": "/opt/ml/model"}} + compute = Compute(instance_type="ml.p4d.24xlarge", instance_count="2") + + gpu_image_cfg = 
_load_recipes_cfg().get("gpu_image") + if isinstance(gpu_image_cfg, str): + expected_training_image = gpu_image_cfg + else: + expected_training_image = image_uris.retrieve( + gpu_image_cfg.get("framework"), + region=modules_session.boto_region_name, + version=gpu_image_cfg.get("version"), + image_scope="training", + **gpu_image_cfg.get("additional_args"), + ) + + expected_distributed = Torchrun(smp=SMP(random_seed=123456)) + expected_hyperparameters = {"config-path": ".", "config-name": "recipe.yaml"} + + networking = Networking( + security_group_ids=["sg-000000000000"], + subnets=["subnet-000000000000"], + enable_network_isolation=True, + enable_inter_container_traffic_encryption=True, + ) + stopping_condition = DEFAULT_STOPPING_CONDITION + output_data_config = DEFAULT_OUTPUT_DATA_CONFIG + local_input_data = InputData( + channel_name="train", data_source=f"{DEFAULT_SOURCE_DIR}/data/train" + ) + input_data_config = [local_input_data] + checkpoint_config = CheckpointConfig( + local_path="/opt/ml/checkpoints", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}/checkpoints", + ) + training_input_mode = "File" + environment = {"ENV_VAR": "value"} + tags = [Tag(key="key", value="value")] + requirements = f"{DEFAULT_SOURCE_DIR}/requirements.txt" + + model_trainer = ModelTrainer.from_recipe( + training_recipe=training_recipe, + recipe_overrides=recipe_overrides, + compute=compute, + networking=networking, + stopping_condition=stopping_condition, + requirements=requirements, + output_data_config=output_data_config, + input_data_config=input_data_config, + checkpoint_config=checkpoint_config, + training_input_mode=training_input_mode, + environment=environment, + tags=tags, + sagemaker_session=modules_session, + role=DEFAULT_ROLE, + base_job_name=DEFAULT_BASE_NAME, + ) + + assert model_trainer.training_image == expected_training_image + assert model_trainer.distributed == expected_distributed + assert model_trainer.hyperparameters == expected_hyperparameters + assert model_trainer.source_code is not None + assert model_trainer.source_code.requirements == "requirements.txt" + + assert model_trainer.compute == compute + assert model_trainer.networking == networking + assert model_trainer.stopping_condition == stopping_condition + assert model_trainer.output_data_config == output_data_config + assert model_trainer.input_data_config == input_data_config + assert model_trainer.checkpoint_config == checkpoint_config + assert model_trainer.training_input_mode == training_input_mode + assert model_trainer.environment == environment + assert model_trainer.tags == tags + + +@patch("sagemaker.modules.train.model_trainer._LocalContainer") +@patch("sagemaker.modules.train.model_trainer._get_unique_name") +@patch("sagemaker.modules.local_core.local_container.download_folder") +def test_model_trainer_local_full_init( + mock_download_folder, mock_unique_name, mock_local_container, modules_session +): + def mock_upload_data(path, bucket, key_prefix): + return f"s3://{bucket}/{key_prefix}" + + modules_session.upload_data.side_effect = mock_upload_data + mock_download_folder.return_value = f"{DEFAULT_SOURCE_DIR}/data/test" + mock_local_container.train.return_value = None + + training_mode = Mode.LOCAL_CONTAINER + role = DEFAULT_ROLE + source_code = DEFAULT_SOURCE_CODE + distributed = Torchrun() + compute = Compute( + instance_type=DEFAULT_INSTANCE_TYPE, + instance_count=1, + volume_size_in_gb=30, + volume_kms_key_id="key-id", + keep_alive_period_in_seconds=3600, + enable_managed_spot_training=True, + ) + networking 
= Networking( + security_group_ids=["sg-000000000000"], + subnets=["subnet-000000000000"], + enable_network_isolation=True, + enable_inter_container_traffic_encryption=True, + ) + stopping_condition = DEFAULT_STOPPING_CONDITION + training_image = DEFAULT_IMAGE + training_image_config = TrainingImageConfig( + training_repository_access_mode="Platform", + training_repository_auth_config=TrainingRepositoryAuthConfig( + training_repository_credentials_provider_arn="arn:aws:lambda:us-west-2:000000000000:function:dummy-function" + ), + ) + output_data_config = DEFAULT_OUTPUT_DATA_CONFIG + + local_input_data = InputData( + channel_name="train", data_source=f"{DEFAULT_SOURCE_DIR}/data/train" + ) + s3_data_source_input = InputData( + channel_name="test", + data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}/data/test", + s3_data_distribution_type="FullyReplicated", + attribute_names=["label"], + instance_group_names=["instance-group"], + ), + ) + file_system_input = InputData( + channel_name="validation", + data_source=FileSystemDataSource( + file_system_id="fs-000000000000", + file_system_access_mode="ro", + file_system_type="EFS", + directory_path="/data/validation", + ), + ) + input_data_config = [local_input_data, s3_data_source_input, file_system_input] + checkpoint_config = CheckpointConfig( + local_path="/opt/ml/checkpoints", + s3_uri=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}/checkpoints", + ) + training_input_mode = "File" + environment = {"ENV_VAR": "value"} + hyperparameters = {"key": "value"} + tags = [Tag(key="key", value="value")] + + local_container_root = os.getcwd() + + model_trainer = ModelTrainer( + training_mode=training_mode, + sagemaker_session=modules_session, + role=role, + source_code=source_code, + distributed=distributed, + compute=compute, + networking=networking, + stopping_condition=stopping_condition, + training_image=training_image, + training_image_config=training_image_config, + output_data_config=output_data_config, + input_data_config=input_data_config, + checkpoint_config=checkpoint_config, + training_input_mode=training_input_mode, + environment=environment, + hyperparameters=hyperparameters, + tags=tags, + local_container_root=local_container_root, + ) + + assert model_trainer.training_mode == training_mode + assert model_trainer.sagemaker_session == modules_session + assert model_trainer.role == role + assert model_trainer.source_code == source_code + assert model_trainer.distributed == distributed + assert model_trainer.compute == compute + assert model_trainer.networking == networking + assert model_trainer.stopping_condition == stopping_condition + assert model_trainer.training_image == training_image + assert model_trainer.training_image_config == training_image_config + assert model_trainer.output_data_config == output_data_config + assert model_trainer.input_data_config == input_data_config + assert model_trainer.checkpoint_config == checkpoint_config + assert model_trainer.training_input_mode == training_input_mode + assert model_trainer.environment == environment + assert model_trainer.hyperparameters == hyperparameters + assert model_trainer.tags == tags + + unique_name = "training-job" + mock_unique_name.return_value = unique_name + + model_trainer.train() + + mock_local_container.assert_called_once_with( + training_job_name=unique_name, + instance_type=compute.instance_type, + instance_count=compute.instance_count, + image=training_image, + container_root=local_container_root, + 
sagemaker_session=modules_session, + container_entrypoint=DEFAULT_ENTRYPOINT, + container_arguments=DEFAULT_ARGUMENTS, + input_data_config=ANY, + hyper_parameters=hyperparameters, + environment=environment, + ) + + +def test_safe_configs(): + # Test extra fails + with pytest.raises(ValueError): + SourceCode(entry_point="train.py") + # Test invalid type fails + with pytest.raises(ValueError): + SourceCode(entry_script=1) + + +@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory") +def test_destructor_cleanup(mock_tmp_dir, modules_session): + + with pytest.raises(ValidationError): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute="test", + ) + mock_tmp_dir.cleanup.assert_not_called() + + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + ) + model_trainer._temp_recipe_train_dir = mock_tmp_dir + mock_tmp_dir.assert_not_called() + del model_trainer + mock_tmp_dir.cleanup.assert_called_once() + + +@patch("os.path.exists") +def test_hyperparameters_valid_json(mock_exists, modules_session): + mock_exists.return_value = True + expected_hyperparameters = {"param1": "value1", "param2": 2} + mock_file_open = mock_open(read_data=json.dumps(expected_hyperparameters)) + + with patch("builtins.open", mock_file_open): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.json", + ) + assert model_trainer.hyperparameters == expected_hyperparameters + mock_file_open.assert_called_once_with("hyperparameters.json", "r") + mock_exists.assert_called_once_with("hyperparameters.json") + + +@patch("os.path.exists") +def test_hyperparameters_valid_yaml(mock_exists, modules_session): + mock_exists.return_value = True + expected_hyperparameters = {"param1": "value1", "param2": 2} + mock_file_open = mock_open(read_data=yaml.dump(expected_hyperparameters)) + + with patch("builtins.open", mock_file_open): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + assert model_trainer.hyperparameters == expected_hyperparameters + mock_file_open.assert_called_once_with("hyperparameters.yaml", "r") + mock_exists.assert_called_once_with("hyperparameters.yaml") + + +def test_hyperparameters_not_exist(modules_session): + with pytest.raises(ValueError): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="nonexistent.json", + ) + + +@patch("os.path.exists") +def test_hyperparameters_invalid(mock_exists, modules_session): + mock_exists.return_value = True + + # YAML contents must be a valid mapping + mock_file_open = mock_open(read_data="- item1\n- item2") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + + # YAML contents must be a valid mapping + mock_file_open = mock_open(read_data="invalid") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a 
valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + + # Must be valid YAML + mock_file_open = mock_open(read_data="* invalid") + with patch("builtins.open", mock_file_open): + with pytest.raises(ValueError, match="Must be a valid JSON or YAML file."): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + hyperparameters="hyperparameters.yaml", + ) + + +@patch("sagemaker.modules.train.model_trainer._get_unique_name") +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_model_trainer_default_paths(mock_training_job, mock_unique_name, modules_session): + def mock_upload_data(path, bucket, key_prefix): + return f"s3://{bucket}/{key_prefix}" + + unique_name = "base-job-0123456789" + base_name = "base-job" + + modules_session.upload_data.side_effect = mock_upload_data + mock_unique_name.return_value = unique_name + + model_trainer = ( + ModelTrainer( + training_image=DEFAULT_IMAGE, + sagemaker_session=modules_session, + base_job_name=base_name, + ) + .with_tensorboard_output_config() + .with_checkpoint_config() + ) + + model_trainer.train() + + _, kwargs = mock_training_job.create.call_args + + default_base_path = f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{base_name}" + + assert kwargs["output_data_config"].s3_output_path == default_base_path + assert kwargs["output_data_config"].compression_type == "GZIP" + + assert kwargs["checkpoint_config"].s3_uri == f"{default_base_path}/{unique_name}/checkpoints" + assert kwargs["checkpoint_config"].local_path == "/opt/ml/checkpoints" + + assert kwargs["tensor_board_output_config"].s3_output_path == default_base_path + assert kwargs["tensor_board_output_config"].local_path == "/opt/ml/output/tensorboard" + + +def test_create_training_job_args(modules_session): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + ) + + args = model_trainer._create_training_job_args() + assert args["algorithm_specification"] == AlgorithmSpecification( + training_image=DEFAULT_IMAGE, + algorithm_name=None, + training_input_mode="File", + container_entrypoint=None, + container_arguments=None, + training_image_config=None, + metric_definitions=None, + ) + assert args["resource_config"] == ResourceConfig( + instance_type=DEFAULT_INSTANCE_TYPE, + instance_count=1, + volume_size_in_gb=30, + ) + assert args["role_arn"] == DEFAULT_ROLE + + +def test_create_training_job_args_boto3(modules_session): + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + ) + + args = model_trainer._create_training_job_args(boto3=True) + assert args["AlgorithmSpecification"] == { + "TrainingImage": DEFAULT_IMAGE, + "TrainingInputMode": "File", + } + assert args["ResourceConfig"] == { + "InstanceType": DEFAULT_INSTANCE_TYPE, + "InstanceCount": 1, + "VolumeSizeInGB": 30, + } + assert args["RoleArn"] == DEFAULT_ROLE + + +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_input_merge(mock_training_job, modules_session): + model_input = InputData(channel_name="model", data_source="s3://bucket/model/model.tar.gz") + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + 
sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + input_data_config=[model_input], + ) + + train_input = InputData(channel_name="train", data_source="s3://bucket/data/train") + model_trainer.train(input_data_config=[train_input]) + + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["input_data_config"] == [ + Channel( + channel_name="model", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri="s3://bucket/model/model.tar.gz", + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ), + Channel( + channel_name="train", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri="s3://bucket/data/train", + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ), + ] + + +@patch("sagemaker.modules.train.model_trainer._get_unique_name") +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_nova_recipe(mock_training_job, mock_unique_name, modules_session): + def mock_upload_data(path, bucket, key_prefix): + if os.path.isfile(path): + file_name = os.path.basename(path) + return f"s3://{bucket}/{key_prefix}/{file_name}" + else: + return f"s3://{bucket}/{key_prefix}" + + unique_name = "base-job-0123456789" + base_name = "base-job" + + modules_session.upload_data.side_effect = mock_upload_data + mock_unique_name.return_value = unique_name + + recipe_data = { + "run": { + "name": "dummy-model", + "model_type": "amazon.nova", + "model_name_or_path": "dummy-model", + } + } + with NamedTemporaryFile(suffix=".yaml", delete=False) as recipe: + with open(recipe.name, "w") as file: + yaml.dump(recipe_data, file) + + trainer = ModelTrainer.from_recipe( + training_recipe=recipe.name, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + training_image=DEFAULT_IMAGE, + base_job_name=base_name, + ) + + assert trainer._is_nova_recipe + + trainer.train() + mock_training_job.create.assert_called_once() + assert mock_training_job.create.call_args.kwargs["hyper_parameters"] == { + "base_model": "dummy-model", + "sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH, + } + + default_base_path = f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{base_name}" + assert mock_training_job.create.call_args.kwargs["input_data_config"] == [ + Channel( + channel_name="recipe", + data_source=DataSource( + s3_data_source=S3DataSource( + s3_data_type="S3Prefix", + s3_uri=f"{default_base_path}/{unique_name}/input/recipe/recipe.yaml", + s3_data_distribution_type="FullyReplicated", + ) + ), + input_mode="File", + ) + ] + + +def test_nova_recipe_with_distillation(modules_session): + recipe_data = {"training_config": {"distillation_data": "true", "kms_key": "alias/my-kms-key"}} + + with NamedTemporaryFile(suffix=".yaml", delete=False) as recipe: + with open(recipe.name, "w") as file: + yaml.dump(recipe_data, file) + + # Create ModelTrainer from recipe + trainer = ModelTrainer.from_recipe( + training_recipe=recipe.name, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + training_image=DEFAULT_IMAGE, + ) + + # Verify that the hyperparameters were set correctly + assert trainer.hyperparameters == { + "distillation_data": "true", + "role_arn": DEFAULT_ROLE, + "kms_key": "alias/my-kms-key", + } + + # Clean up the temporary file + os.unlink(recipe.name) diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py 
b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index 53119e532a..bdbba955a4 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -89,6 +89,7 @@ subnets=SUBNETS, ) CRON_HOURLY = CronExpressionGenerator.hourly() +CRON_NOW = CronExpressionGenerator.now() ENDPOINT_NAME = "endpoint" GROUND_TRUTH_S3_URI = "s3://bucket/monitoring_captured/actuals" ANALYSIS_CONFIG_S3_URI = "s3://bucket/analysis_config.json" @@ -568,11 +569,12 @@ def test_clarify_model_monitor(): # The subclass should has monitoring_type() defined # noinspection PyAbstractClass - class DummyClarifyModelMonitoir(ClarifyModelMonitor): + class DummyClarifyModelMonitor(ClarifyModelMonitor): + _TEST_CLASS = True pass with pytest.raises(TypeError): - DummyClarifyModelMonitoir.monitoring_type() + DummyClarifyModelMonitor.monitoring_type() def test_clarify_model_monitor_invalid_update(clarify_model_monitors): @@ -593,6 +595,8 @@ def test_clarify_model_monitor_invalid_attach(sagemaker_session): ) # attach, invalid monitoring type for clarify_model_monitor_cls in ClarifyModelMonitor.__subclasses__(): + if hasattr(clarify_model_monitor_cls, "_TEST_CLASS"): + continue with pytest.raises(TypeError): clarify_model_monitor_cls.attach(SCHEDULE_NAME, sagemaker_session) @@ -1302,6 +1306,66 @@ def test_model_explainability_monitor(model_explainability_monitor, sagemaker_se ) +def test_model_explainability_create_one_time_schedule( + model_explainability_monitor, sagemaker_session +): + endpoint_input = EndpointInput( + endpoint_name=ENDPOINT_NAME, + destination=ENDPOINT_INPUT_LOCAL_PATH, + features_attribute=FEATURES_ATTRIBUTE, + inference_attribute=str(INFERENCE_ATTRIBUTE), + ) + + # Create one-time schedule + with patch( + "sagemaker.s3.S3Uploader.upload_string_as_file_body", return_value=ANALYSIS_CONFIG_S3_URI + ) as _: + model_explainability_monitor.create_monitoring_schedule( + endpoint_input=endpoint_input, + analysis_config=ANALYSIS_CONFIG_S3_URI, + output_s3_uri=OUTPUT_S3_URI, + monitor_schedule_name=SCHEDULE_NAME, + schedule_cron_expression=CRON_NOW, + data_analysis_start_time=START_TIME_OFFSET, + data_analysis_end_time=END_TIME_OFFSET, + ) + + # Validate job definition creation + sagemaker_session.sagemaker_client.create_model_explainability_job_definition.assert_called_once() + job_definition_args = ( + sagemaker_session.sagemaker_client.create_model_explainability_job_definition.call_args[1] + ) + assert ( + job_definition_args["JobDefinitionName"] == model_explainability_monitor.job_definition_name + ) + assert job_definition_args == { + "JobDefinitionName": model_explainability_monitor.job_definition_name, + **EXPLAINABILITY_JOB_DEFINITION, + "Tags": TAGS, + } + + # Validate monitoring schedule creation + sagemaker_session.sagemaker_client.create_monitoring_schedule.assert_called_once() + schedule_args = sagemaker_session.sagemaker_client.create_monitoring_schedule.call_args[1] + assert schedule_args == { + "MonitoringScheduleName": SCHEDULE_NAME, + "MonitoringScheduleConfig": { + "MonitoringJobDefinitionName": model_explainability_monitor.job_definition_name, + "MonitoringType": "ModelExplainability", + "ScheduleConfig": { + "ScheduleExpression": CRON_NOW, + "DataAnalysisStartTime": START_TIME_OFFSET, + "DataAnalysisEndTime": END_TIME_OFFSET, + }, + }, + "Tags": TAGS, + } + + # Check if the monitoring schedule is stored in the monitor object + assert model_explainability_monitor.monitoring_schedule_name == SCHEDULE_NAME + assert 
model_explainability_monitor.job_definition_name is not None + + def test_model_explainability_batch_transform_monitor( model_explainability_monitor, sagemaker_session ): diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index d31b9f8527..b338885491 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -73,6 +73,7 @@ LINFINITY_METHOD = "LInfinity" CRON_DAILY = CronExpressionGenerator.daily() +CRON_NOW = CronExpressionGenerator.now() BASELINING_JOB_NAME = "baselining-job" BASELINE_DATASET_PATH = "/my/local/path/baseline.csv" PREPROCESSOR_PATH = "/my/local/path/preprocessor.py" @@ -1136,6 +1137,36 @@ def _test_data_quality_monitor_update_schedule(data_quality_monitor, sagemaker_s sagemaker_session.sagemaker_client.delete_data_quality_job_definition.assert_not_called() sagemaker_session.sagemaker_client.create_data_quality_job_definition.assert_not_called() + # update schedule + sagemaker_session.describe_monitoring_schedule = MagicMock() + sagemaker_session.sagemaker_client.describe_data_quality_job_definition = MagicMock() + sagemaker_session.sagemaker_client.create_data_quality_job_definition = MagicMock() + + # Test updating monitoring schedule with schedule_cron_expression set to NOW + sagemaker_session.sagemaker_client.update_monitoring_schedule = Mock() + data_quality_monitor.update_monitoring_schedule( + data_analysis_start_time="-PT24H", + data_analysis_end_time="-PT0H", + schedule_cron_expression=CRON_NOW, + ) + + sagemaker_session.sagemaker_client.update_monitoring_schedule.assert_called_once_with( + MonitoringScheduleName=data_quality_monitor.monitoring_schedule_name, + MonitoringScheduleConfig={ + "MonitoringJobDefinitionName": data_quality_monitor.job_definition_name, + "MonitoringType": DefaultModelMonitor.monitoring_type(), + "ScheduleConfig": { + "ScheduleExpression": CRON_NOW, + "DataAnalysisStartTime": "-PT24H", + "DataAnalysisEndTime": "-PT0H", + }, + }, + ) + + # A new data quality job definition should be created + sagemaker_session.sagemaker_client.describe_data_quality_job_definition.assert_called_once() + sagemaker_session.sagemaker_client.create_data_quality_job_definition.assert_called_once() + # update one property of job definition time.sleep( 0.001 diff --git a/tests/unit/sagemaker/partner_app/__init__.py b/tests/unit/sagemaker/partner_app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/partner_app/test_auth_provider.py b/tests/unit/sagemaker/partner_app/test_auth_provider.py new file mode 100644 index 0000000000..c5a27cff3a --- /dev/null +++ b/tests/unit/sagemaker/partner_app/test_auth_provider.py @@ -0,0 +1,152 @@ +from __future__ import absolute_import + +import os +import unittest +from unittest.mock import patch, MagicMock +from requests import PreparedRequest +from sagemaker.partner_app.auth_provider import RequestsAuth, PartnerAppAuthProvider + + +class TestRequestsAuth(unittest.TestCase): + + @patch("sagemaker.partner_app.auth_provider.PartnerAppAuthUtils.get_signed_request") + @patch("sagemaker.partner_app.auth_provider.SigV4Auth") + def test_requests_auth_call(self, mock_sigv4_auth, mock_get_signed_request): + # Prepare mock data + mock_signed_url = "https://returned-url.test.com/" + mock_signed_headers = {"Authorization": "SigV4", "x-amz-date": "20241016T120000Z"} + mock_get_signed_request.return_value = (mock_signed_url, mock_signed_headers) + + # Create the 
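
Note: the one-time-schedule assertions above and below pin down the CreateMonitoringSchedule request shape when CronExpressionGenerator.now() is used, which presumably resolves to the literal "NOW" schedule expression. A sketch of that shape, with placeholder names:

one_time_schedule_request = {
    "MonitoringScheduleName": "one-time-schedule",    # placeholder
    "MonitoringScheduleConfig": {
        "MonitoringJobDefinitionName": "my-job-def",  # placeholder
        "MonitoringType": "ModelExplainability",
        "ScheduleConfig": {
            "ScheduleExpression": "NOW",              # run once, immediately
            "DataAnalysisStartTime": "-PT24H",        # analyze the last 24 hours
            "DataAnalysisEndTime": "-PT0H",
        },
    },
}
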
objects needed for testing + app_arn = "arn:aws:lambda:us-west-2:123456789012:sagemaker:test" + under_test = RequestsAuth(sigv4=mock_sigv4_auth, app_arn=app_arn) + + # Create a prepared request object to simulate an actual request + request = PreparedRequest() + request.method = "GET" + request_url = "https://test.com" + request.url = request_url + request_headers = {} + request.headers = request_headers + request.body = "{}" + + # Call the method under test + updated_request = under_test(request) + + # Assertions to verify the behavior + mock_get_signed_request.assert_called_once_with( + sigv4=mock_sigv4_auth, + app_arn=app_arn, + url=request_url, + method="GET", + headers=request_headers, + body=request.body, + ) + + self.assertEqual(updated_request.url, mock_signed_url) + self.assertIn("Authorization", updated_request.headers) + self.assertIn("x-amz-date", updated_request.headers) + self.assertEqual(updated_request.headers["Authorization"], "SigV4") + self.assertEqual(updated_request.headers["x-amz-date"], "20241016T120000Z") + + +class TestPartnerAppAuthProvider(unittest.TestCase): + + @patch("sagemaker.partner_app.auth_provider.boto3.Session") + @patch("sagemaker.partner_app.auth_provider.SigV4Auth") + @patch("sagemaker.partner_app.auth_provider.PartnerAppAuthUtils.get_signed_request") + def test_get_signed_request( + self, mock_get_signed_request, mock_sigv4auth_class, mock_boto3_session + ): + # Set up environment variable + test_app_arn = "arn:aws-us-gov:sagemaker:us-west-2:123456789012:partner-app/my-app" + os.environ["AWS_PARTNER_APP_ARN"] = test_app_arn + + # Mock the return value of boto3.Session().get_credentials() + mock_credentials = MagicMock() + mock_boto3_session.return_value.get_credentials.return_value = mock_credentials + + # Mock the SigV4Auth instance + mock_sigv4auth_instance = MagicMock() + mock_sigv4auth_class.return_value = mock_sigv4auth_instance + + # Initialize the PartnerAppAuthProvider class + provider = PartnerAppAuthProvider() + + # Mock return value for get_signed_request + mock_get_signed_request.return_value = { + "Authorization": "SigV4", + "x-amz-date": "20241016T120000Z", + } + + # Call get_signed_request method + signed_request = provider.get_signed_request( + url="https://example.com", + method="GET", + headers={"Content-Type": "application/json"}, + body=None, + ) + + # Assert that the get_signed_request method was called with correct parameters + mock_get_signed_request.assert_called_once_with( + sigv4=mock_sigv4auth_instance, + app_arn=test_app_arn, + url="https://example.com", + method="GET", + headers={"Content-Type": "application/json"}, + body=None, + ) + + # Assert the response matches the mocked return value + self.assertEqual(signed_request["Authorization"], "SigV4") + self.assertEqual(signed_request["x-amz-date"], "20241016T120000Z") + + @patch("sagemaker.partner_app.auth_provider.SigV4Auth") + def test_get_auth(self, mock_sigv4auth_class): + # Set up environment variable + os.environ["AWS_PARTNER_APP_ARN"] = ( + "arn:aws:sagemaker:us-west-2:123456789012:partner-app/app-abc" + ) + + # Mock the SigV4Auth instance + mock_sigv4auth_instance = MagicMock() + mock_sigv4auth_class.return_value = mock_sigv4auth_instance + + # Initialize the PartnerAppAuthProvider class + provider = PartnerAppAuthProvider() + + # Call get_auth method + auth_instance = provider.get_auth() + + # Assert that the returned object is a RequestsAuth instance + self.assertIsInstance(auth_instance, RequestsAuth) + + # Assert that RequestsAuth was initialized with correct 
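
Note: as the __call__ test above shows, RequestsAuth is invoked with a PreparedRequest, i.e. it follows the `requests` custom-auth callable protocol. A hedged usage sketch, assuming AWS_PARTNER_APP_ARN is set in the environment; the URL is a placeholder:

import requests

from sagemaker.partner_app.auth_provider import PartnerAppAuthProvider

provider = PartnerAppAuthProvider()
# get_auth() returns a RequestsAuth instance usable directly as `auth=`.
response = requests.get("https://partner-app.example.com/api", auth=provider.get_auth())
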
arguments + self.assertEqual(auth_instance.sigv4, mock_sigv4auth_instance) + self.assertEqual(auth_instance.app_arn, os.environ["AWS_PARTNER_APP_ARN"]) + + def test_init_raises_value_error_with_missing_app_arn(self): + # Remove the environment variable + if "AWS_PARTNER_APP_ARN" in os.environ: + del os.environ["AWS_PARTNER_APP_ARN"] + + # Ensure ValueError is raised when AWS_PARTNER_APP_ARN is not set + with self.assertRaises(ValueError) as context: + PartnerAppAuthProvider() + + self.assertIn( + "Must specify the AWS_PARTNER_APP_ARN environment variable", str(context.exception) + ) + + def test_init_raises_value_error_with_invalid_app_arn(self): + os.environ["AWS_PARTNER_APP_ARN"] = ( + "arn:aws:lambda:us-west-2:123456789012:function:my-function" + ) + + # Ensure ValueError is raised when AWS_PARTNER_APP_ARN is invalid + with self.assertRaises(ValueError) as context: + PartnerAppAuthProvider() + + self.assertIn( + "Must specify a valid AWS_PARTNER_APP_ARN environment variable", str(context.exception) + ) diff --git a/tests/unit/sagemaker/partner_app/test_auth_utils.py b/tests/unit/sagemaker/partner_app/test_auth_utils.py new file mode 100644 index 0000000000..b75dc9a30e --- /dev/null +++ b/tests/unit/sagemaker/partner_app/test_auth_utils.py @@ -0,0 +1,111 @@ +from __future__ import absolute_import + +import unittest +from unittest.mock import Mock, patch +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +from hashlib import sha256 + +from sagemaker.partner_app.auth_utils import ( + PartnerAppAuthUtils, + EMPTY_SHA256_HASH, + UNSIGNED_PAYLOAD, +) + + +class TestPartnerAppAuthUtils(unittest.TestCase): + def setUp(self): + self.sigv4_mock = Mock(spec=SigV4Auth) + self.app_arn = "arn:aws:sagemaker:us-west-2:123456789012:partner-app/abc123" + self.url = "https://partner-app-abc123.us-west-2.amazonaws.com?fileName=Jupyter+interactive" + self.method = "POST" + self.headers = {"Authorization": "API_KEY", "Connection": "conn"} + self.body = b'{"key": "value"}' # Byte type body for hashing + + @patch("sagemaker.partner_app.auth_utils.AWSRequest") + def test_get_signed_request_with_body(self, AWSRequestMock): + aws_request_mock = Mock(spec=AWSRequest) + AWSRequestMock.return_value = aws_request_mock + + expected_hash = sha256(self.body).hexdigest() + # Authorization still has the original value as the sigv4 mock does not add this header + expected_sign_headers = { + "Authorization": "API_KEY", + "X-Amz-Partner-App-Authorization": "API_KEY", + "X-SageMaker-Partner-App-Server-Arn": self.app_arn, + "X-Amz-Target": "SageMaker.CallPartnerAppApi", + "X-Amz-Content-SHA256": expected_hash, + } + aws_request_mock.headers = expected_sign_headers + + # Mock the add_auth method on the SigV4Auth + self.sigv4_mock.add_auth = Mock() + + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + self.sigv4_mock, self.app_arn, self.url, self.method, self.headers, self.body + ) + + # Assert X-SageMaker-Partner-App-Server-Arn header is correct + self.assertEqual(signed_headers["X-SageMaker-Partner-App-Server-Arn"], self.app_arn) + + # Assert the Authorization header was moved to X-Amz-Partner-App-Authorization + self.assertIn("X-Amz-Partner-App-Authorization", signed_headers) + + # Assert X-Amz-Content-SHA256 is set + self.assertEqual(signed_headers["X-Amz-Content-SHA256"], expected_hash) + + # Assert the Connection header is preserved + self.assertEqual(signed_headers["Connection"], "conn") + + expected_canonical_url = self.url.replace("+", "%20") + # Assert AWSRequestMock was
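
Note: the body cases above (bytes, None, seekable stream, anything else) pin down a payload-hash rule. A minimal sketch that mirrors the asserted behavior, not the PartnerAppAuthUtils source; botocore's EMPTY_SHA256_HASH constant is recomputed locally:

from hashlib import sha256

EMPTY_SHA256_HASH = sha256(b"").hexdigest()
UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD"


def body_hash(body) -> str:
    if body is None:
        return EMPTY_SHA256_HASH           # empty body: digest of the empty string
    if isinstance(body, (bytes, bytearray)):
        return sha256(body).hexdigest()    # in-memory payloads are hashed directly
    if hasattr(body, "seek") and hasattr(body, "read"):
        position = body.tell()             # file-like payloads: hash the stream,
        digest = sha256()                  # then rewind so the request can still send it
        for chunk in iter(lambda: body.read(8192), b""):
            digest.update(chunk)
        body.seek(position)
        return digest.hexdigest()
    return UNSIGNED_PAYLOAD                # unrecognized body types are left unsigned
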
called + AWSRequestMock.assert_called_once_with( + method=self.method, + url=expected_canonical_url, + headers=expected_sign_headers, + data=self.body, + ) + + def test_get_signed_request_with_no_body(self): + body = None + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + self.sigv4_mock, self.app_arn, self.url, self.method, self.headers, body + ) + + # Assert X-Amz-Content-SHA256 is EMPTY_SHA256_HASH + self.assertEqual(signed_headers["X-Amz-Content-SHA256"], EMPTY_SHA256_HASH) + + def test_get_signed_request_with_bytes_body(self): + body = Mock() + body.seek = Mock() + body.tell = Mock(return_value=0) + body.read = Mock(side_effect=[b"test", b""]) + + url, signed_headers = PartnerAppAuthUtils.get_signed_request( + self.sigv4_mock, self.app_arn, self.url, self.method, self.headers, body + ) + + # Verify the seek method was called + body.seek.assert_called() + + # Calculate the expected checksum for the body + checksum = sha256(b"test").hexdigest() + + # Assert X-Amz-Content-SHA256 is the calculated checksum + self.assertEqual(signed_headers["X-Amz-Content-SHA256"], checksum) + + def test_get_body_header_unsigned_payload(self): + body = {"key": "value"} + + result = PartnerAppAuthUtils.get_body_header(body) + + # Assert the result is UNSIGNED_PAYLOAD for unrecognized body type + self.assertEqual(result, UNSIGNED_PAYLOAD) + + def test_get_body_header_empty_body(self): + body = None + + result = PartnerAppAuthUtils.get_body_header(body) + + # Assert the result is EMPTY_SHA256_HASH for empty body + self.assertEqual(result, EMPTY_SHA256_HASH) diff --git a/tests/unit/sagemaker/remote_function/core/test_serialization.py b/tests/unit/sagemaker/remote_function/core/test_serialization.py index a0742240ea..e87dc39b59 100644 --- a/tests/unit/sagemaker/remote_function/core/test_serialization.py +++ b/tests/unit/sagemaker/remote_function/core/test_serialization.py @@ -99,6 +99,7 @@ def test_serialize_deserialize_lambda(): assert deserialized(3) == 9 +@pytest.mark.flaky(reruns=3, reruns_delay=5) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @patch("sagemaker.s3.S3Downloader.read_bytes", new=read) @patch("sagemaker.experiments.run.Experiment") @@ -106,9 +107,11 @@ def test_serialize_deserialize_lambda(): @patch("sagemaker.experiments.run._TrialComponent._load_or_create", return_value=(Mock(), False)) @patch("sagemaker.experiments.run._MetricsManager") @patch("sagemaker.remote_function.job.Session") -def test_serialize_func_referencing_to_run(*args, **kwargs): +def test_serialize_func_referencing_to_run(sagemaker_session, *args, **kwargs): - with Run(experiment_name="exp_name", run_name="run_name") as run: + with Run( + sagemaker_session=sagemaker_session, experiment_name="exp_name", run_name="run_name" + ) as run: def train(x): return run.log_metric() @@ -302,6 +305,7 @@ def test_serialize_deserialize_none(): assert deserialized is None +@pytest.mark.flaky(reruns=3, reruns_delay=5) @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @patch("sagemaker.s3.S3Downloader.read_bytes", new=read) @patch("sagemaker.experiments.run.Experiment") @@ -309,8 +313,10 @@ def test_serialize_deserialize_none(): @patch("sagemaker.experiments.run._TrialComponent._load_or_create", return_value=(Mock(), False)) @patch("sagemaker.experiments.run._MetricsManager") @patch("sagemaker.remote_function.job.Session") -def test_serialize_run(*args, **kwargs): - with Run(experiment_name="exp_name", run_name="run_name") as run: +def test_serialize_run(sagemaker_session, *args, **kwargs): + with Run( + 
sagemaker_session=sagemaker_session, experiment_name="exp_name", run_name="run_name" + ) as run: s3_uri = random_s3_uri() with pytest.raises( SerializationError, diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py new file mode 100644 index 0000000000..aa983141ae --- /dev/null +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_mpi_utils.py @@ -0,0 +1,125 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""MPI Utils Unit Tests.""" +from __future__ import absolute_import + +import os +from mock import patch + +import sagemaker.remote_function.runtime_environment.mpi_utils_remote as mpi_utils_remote # noqa: E402 + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-1", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +def test_mpi_utils_main_job_start( + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main() + + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_called_once() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-2", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +def test_mpi_utils_worker_job_start( + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main() + + mock_start_sshd_daemon.assert_called_once() + mock_bootstrap_worker_node.assert_called_once() + mock_bootstrap_master_node.assert_not_called() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-1", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +@patch( + "sagemaker.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers" +) +def test_mpi_utils_main_job_end( + mock_write_status_file_to_workers, + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main(["--job_ended", "1"]) + + mock_start_sshd_daemon.assert_not_called() + 
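
Note: the four mpi_utils_remote tests here pin down a simple role dispatch keyed on SM_CURRENT_HOST versus SM_MASTER_ADDR. A sketch of that dispatch; the helpers are stand-in stubs, not the real implementations:

import json
import os


def start_sshd_daemon():
    """Stand-in: the real helper presumably starts sshd so mpirun can reach each node."""


def bootstrap_master_node():
    """Stand-in for the master-side bootstrap."""


def bootstrap_worker_node():
    """Stand-in for the worker-side bootstrap."""


def write_status_file_to_workers(worker_hosts):
    """Stand-in: signals workers that the job has finished."""


def main(args=None):
    is_master = os.environ["SM_CURRENT_HOST"] == os.environ["SM_MASTER_ADDR"]
    hosts = json.loads(os.environ["SM_HOSTS"])
    if args and "--job_ended" in args:
        if is_master:  # teardown: only the master notifies workers; workers do nothing
            write_status_file_to_workers([h for h in hosts if h != os.environ["SM_MASTER_ADDR"]])
        return
    start_sshd_daemon()  # both roles run sshd during startup
    if is_master:
        bootstrap_master_node()
    else:
        bootstrap_worker_node()
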
mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_not_called() + mock_write_status_file_to_workers.assert_called_once() + + +@patch.dict( + os.environ, + { + "SM_MASTER_ADDR": "algo-1", + "SM_CURRENT_HOST": "algo-2", + "SM_HOSTS": '["algo-1", "algo-2"]', + }, +) +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_master_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.bootstrap_worker_node") +@patch("sagemaker.remote_function.runtime_environment.mpi_utils_remote.start_sshd_daemon") +@patch( + "sagemaker.remote_function.runtime_environment.mpi_utils_remote.write_status_file_to_workers" +) +def test_mpi_utils_worker_job_end( + mock_write_status_file_to_workers, + mock_start_sshd_daemon, + mock_bootstrap_worker_node, + mock_bootstrap_master_node, +): + + mpi_utils_remote.main(["--job_ended", "1"]) + + mock_start_sshd_daemon.assert_not_called() + mock_bootstrap_worker_node.assert_not_called() + mock_bootstrap_master_node.assert_not_called() + mock_write_status_file_to_workers.assert_not_called() diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index 20d05a933e..de8758bfad 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -15,6 +15,7 @@ import os import threading import time +import inspect import pytest from mock import MagicMock, patch, Mock, ANY, call @@ -1498,17 +1499,20 @@ def test_consistency_between_remote_and_step_decorator(): from sagemaker.workflow.function_step import step remote_args_to_ignore = [ - "_remote", "include_local_workdir", "custom_file_filter", "s3_kms_key", "s3_root_uri", "sagemaker_session", + "disable_output_compression", + "use_torchrun", + "use_mpirun", + "nproc_per_node", ] step_args_to_ignore = ["_step", "name", "display_name", "description", "retry_policies"] - remote_decorator_args = remote.__code__.co_varnames + remote_decorator_args = inspect.signature(remote).parameters.keys() common_remote_decorator_args = set(remote_args_to_ignore) ^ set(remote_decorator_args) step_decorator_args = step.__code__.co_varnames @@ -1522,8 +1526,7 @@ def test_consistency_between_remote_and_executor(): executor_arg_list.remove("self") executor_arg_list.remove("max_parallel_jobs") - remote_args_list = list(remote.__code__.co_varnames) - remote_args_list.remove("_remote") + remote_args_list = list(inspect.signature(remote).parameters.keys()) remote_args_list.remove("_func") assert executor_arg_list == remote_args_list diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index 98961ad80d..f153b5b2ca 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -15,6 +15,7 @@ import os import sys +import tempfile import pytest from mock import patch, Mock, ANY, mock_open from mock.mock import MagicMock @@ -49,6 +50,11 @@ _prepare_dependencies_and_pre_execution_scripts, ) +from sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment import ( + set_env, + safe_serialize, +) + REGION = "us-west-2" TRAINING_JOB_ARN = "training-job-arn" @@ -68,6 +74,178 @@ EXPECTED_OUTPUT_URI = S3_URI + "/output" EXPECTED_DEPENDENCIES_URI = S3_URI + "/additional_dependencies/requirements.txt" +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_CPU = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export 
SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.t3.xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='4' +export SM_NUM_GPUS='0' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='4' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.t3.xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 4, "num_gpus": 0, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.t3.xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.t3.xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +""" + +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='4' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", 
"hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export NCCL_SOCKET_IFNAME='eth0' +export NCCL_PROTO='simple' +""" + +# flake8: noqa +EXPECTED_ENV_MULTI_NODE_MULTI_GPUS = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.2xlarge' +export SM_HOSTS='["algo-1", "algo-2", "algo-3", "algo-4"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='4' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='8' +export SM_NUM_GPUS='1' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='1' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.2xlarge", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "host_count": 4, "nproc_per_node": 1, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 8, "num_gpus": 1, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}}' +export NCCL_SOCKET_IFNAME='eth0' +export NCCL_PROTO='simple' +""" + +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='4' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 4, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": 
"/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:4' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + +# flake8: noqa +EXPECTED_ENV_MULTI_NODE_MULTI_GPUS_MPIRUN = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export SM_CURRENT_INSTANCE_TYPE='ml.g5.2xlarge' +export SM_HOSTS='["algo-1", "algo-2", "algo-3", "algo-4"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='4' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='8' +export SM_NUM_GPUS='1' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='1' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.2xlarge", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "host_count": 4, "nproc_per_node": 1, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 8, "num_gpus": 1, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1", "algo-2", "algo-3", "algo-4"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.2xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.2xlarge", "hosts": ["algo-4", "algo-2", "algo-1", "algo-3"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:1,algo-2:1,algo-3:1,algo-4:1' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + +# flake8: noqa +EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN_WITH_NPROC_PER_NODE = """ +export SM_MODEL_DIR='/opt/ml/model' +export SM_INPUT_DIR='/opt/ml/input' +export SM_INPUT_DATA_DIR='/opt/ml/input/data' +export SM_INPUT_CONFIG_DIR='/opt/ml/input/config' +export SM_OUTPUT_DIR='/opt/ml/output' +export SM_OUTPUT_FAILURE='/opt/ml/output/failure' +export SM_OUTPUT_DATA_DIR='/opt/ml/output/data' +export SM_MASTER_ADDR='algo-1' +export SM_MASTER_PORT='7777' +export SM_CURRENT_HOST='algo-1' +export 
SM_CURRENT_INSTANCE_TYPE='ml.g5.12xlarge' +export SM_HOSTS='["algo-1"]' +export SM_NETWORK_INTERFACE_NAME='eth0' +export SM_HOST_COUNT='1' +export SM_CURRENT_HOST_RANK='0' +export SM_NUM_CPUS='48' +export SM_NUM_GPUS='4' +export SM_NUM_NEURONS='0' +export SM_RESOURCE_CONFIG='{"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}' +export SM_NPROC_PER_NODE='2' +export SM_TRAINING_ENV='{"current_host": "algo-1", "current_instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"], "host_count": 1, "nproc_per_node": 2, "master_addr": "algo-1", "master_port": 7777, "input_config_dir": "/opt/ml/input/config", "input_data_dir": "/opt/ml/input/data", "input_dir": "/opt/ml/input", "job_name": "test-job", "model_dir": "/opt/ml/model", "network_interface_name": "eth0", "num_cpus": 48, "num_gpus": 4, "num_neurons": 0, "output_data_dir": "/opt/ml/output/data", "resource_config": {"current_host": "algo-1", "hosts": ["algo-1"], "current_group_name": "homogeneousCluster", "current_instance_type": "ml.g5.12xlarge", "instance_groups": [{"instance_group_name": "homogeneousCluster", "instance_type": "ml.g5.12xlarge", "hosts": ["algo-1"]}], "network_interface_name": "eth0"}}' +export MASTER_ADDR='algo-1' +export MASTER_PORT='7777' +export SM_HOSTS_LIST='algo-1:2' +export SM_FI_PROVIDER='' +export SM_NCCL_PROTO='' +export SM_FI_EFA_USE_DEVICE_RDMA='' +""" + DESCRIBE_TRAINING_JOB_RESPONSE = { "TrainingJobArn": TRAINING_JOB_ARN, "TrainingJobStatus": "{}", @@ -112,8 +290,8 @@ def mock_get_current_run(): return current_run -def describe_training_job_response(job_status): - return { +def describe_training_job_response(job_status, disable_output_compression=False): + job_response = { "TrainingJobArn": TRAINING_JOB_ARN, "TrainingJobStatus": job_status, "ResourceConfig": { @@ -121,15 +299,38 @@ "InstanceCount": 1, "InstanceType": "ml.c4.xlarge", "VolumeSizeInGB": 30, }, - "OutputDataConfig": {"S3OutputPath": "s3://sagemaker-123/image_uri/output"}, } + if disable_output_compression: + output_config = { + "S3OutputPath": "s3://sagemaker-123/image_uri/output", + "CompressionType": "NONE", + } + else: + output_config = { + "S3OutputPath": "s3://sagemaker-123/image_uri/output", + "CompressionType": "GZIP", + } + + job_response["OutputDataConfig"] = output_config + + return job_response + COMPLETED_TRAINING_JOB = describe_training_job_response("Completed") INPROGRESS_TRAINING_JOB = describe_training_job_response("InProgress") CANCELLED_TRAINING_JOB = describe_training_job_response("Stopped") FAILED_TRAINING_JOB = describe_training_job_response("Failed") +COMPLETED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response( + "Completed", True +) +INPROGRESS_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response( + "InProgress", True +) +CANCELLED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response("Stopped", True) +FAILED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION = describe_training_job_response("Failed", True) + def mock_session(): session = Mock() @@ -389,6 +590,8 @@ def test_start( s3_base_uri=f"{S3_URI}/{job.job_name}", s3_kms_key=None, sagemaker_session=session(), + use_torchrun=False, + use_mpirun=False, ) mock_dependency_upload.assert_called_once_with( @@ -670,6 +873,8 @@ def
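
Note: across the expected environments above, SM_NPROC_PER_NODE follows a visible pattern. A sketch of that rule as a hypothetical helper, not the SDK's actual derivation: an explicit user value wins, otherwise the accelerator count, otherwise the vCPU count.

def nproc_per_node(num_cpus: int, num_gpus: int, num_neurons: int, user_value=None) -> int:
    if user_value is not None:
        return user_value    # nproc_per_node=2 -> SM_NPROC_PER_NODE='2'
    if num_gpus > 0:
        return num_gpus      # ml.g5.12xlarge: 4 GPUs -> '4'
    if num_neurons > 0:
        return num_neurons   # Trainium/Inferentia hosts
    return num_cpus          # ml.t3.xlarge: 4 vCPUs -> '4'
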
test_start_with_complete_job_settings( s3_base_uri=f"{S3_URI}/{job.job_name}", s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), + use_torchrun=False, + use_mpirun=False, ) mock_user_workspace_upload.assert_called_once_with( @@ -840,6 +1045,8 @@ def test_get_train_args_under_pipeline_context( s3_base_uri=s3_base_uri, s3_kms_key=job_settings.s3_kms_key, sagemaker_session=session(), + use_torchrun=False, + use_mpirun=False, ) mock_user_workspace_upload.assert_called_once_with( @@ -1014,6 +1221,8 @@ def test_start_with_spark( s3_base_uri=f"{S3_URI}/{job.job_name}", s3_kms_key=None, sagemaker_session=session(), + use_torchrun=False, + use_mpirun=False, ) session().sagemaker_client.create_training_job.assert_called_once_with( @@ -1116,6 +1325,27 @@ def test_describe(session, *args): session().sagemaker_client.describe_training_job.assert_called_once() +@patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts") +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_describe_disable_output_compression(session, *args): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + instance_type="ml.m5.large", + disable_output_compression=True, + ) + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + job.describe() + assert job.describe() == COMPLETED_TRAINING_JOB_DISABLE_OUTPUT_COMPRESSION + + session().sagemaker_client.describe_training_job.assert_called_once() + + @patch("sagemaker.remote_function.job._prepare_and_upload_runtime_scripts") @patch("sagemaker.remote_function.job._prepare_and_upload_workspace") @patch("sagemaker.remote_function.job.StoredFunction") @@ -1172,7 +1402,7 @@ def test_prepare_and_upload_runtime_scripts(session, mock_copy, mock_s3_upload): assert s3_path == mock_s3_upload.return_value - assert mock_copy.call_count == 2 + assert mock_copy.call_count == 3 mock_s3_upload.assert_called_once() @@ -1192,7 +1422,7 @@ def test_prepare_and_upload_runtime_scripts_under_pipeline_context( ) # Bootstrap scripts are uploaded on the first call assert s3_path == mock_s3_upload.return_value - assert mock_copy.call_count == 2 + assert mock_copy.call_count == 3 mock_s3_upload.assert_called_once() mock_copy.reset_mock() @@ -1601,3 +1831,848 @@ def test_extend_spark_config_to_request( } ], ) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_torchrun_single_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=True, + use_mpirun=False, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), 
func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=True, + use_mpirun=False, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "torchrun", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_torchrun_multi_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + 
include_local_workdir=True, + instance_count=2, + instance_type="ml.g5.2xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=True, + use_mpirun=False, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=True, + use_mpirun=False, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + "S3DataDistributionType": "FullyReplicated", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + "S3DataDistributionType": "FullyReplicated", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "torchrun", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=2, + InstanceType="ml.g5.2xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) 
+def test_set_env_single_node_cpu( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.t3.xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.t3.xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution=None, + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_CPU) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution="torchrun", + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=8, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=1, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_multi_node_multi_gpu( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1", "algo-2", "algo-3", "algo-4"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.2xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.2xlarge", + hosts=["algo-4", "algo-2", "algo-1", "algo-3"], + ) + ], + network_interface_name="eth0", + ), + distribution="torchrun", + output_file=f.name, + ) 
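
Note: the set_env tests that follow read the output file back and compare it line by line, which pins down the format: one export KEY='value' line per variable. A minimal sketch of a writer producing that format (the file is presumably sourced by the bootstrap scripts before user code runs; the path below is a placeholder):

def write_env_file(env_vars: dict, output_file: str) -> None:
    # Emit shell-sourceable assignments, matching the expected fixtures above.
    with open(output_file, "w") as f:
        for key, value in env_vars.items():
            f.write(f"export {key}='{value}'\n")


write_env_file({"SM_MASTER_ADDR": "algo-1", "SM_MASTER_PORT": "7777"}, "/tmp/sm_training.env")
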
+ + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_MULTI_NODE_MULTI_GPUS) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu_mpirun( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=8, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=1, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_multi_node_multi_gpu_mpirun( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1", "algo-2", "algo-3", "algo-4"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.2xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.2xlarge", + hosts=["algo-4", "algo-2", "algo-1", "algo-3"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines(EXPECTED_ENV_MULTI_NODE_MULTI_GPUS_MPIRUN) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", 
return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_torchrun_single_node_with_nproc_per_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=True, + use_mpirun=False, + nproc_per_node=2, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=True, + use_mpirun=False, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "torchrun", + "--user_nproc_per_node", + "2", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), + ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": 
"us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch("sagemaker.experiments._run_context._RunContext.get_current_run", new=mock_get_current_run) +@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("sagemaker.remote_function.job._prepare_and_upload_workspace", return_value="some_s3_uri") +@patch( + "sagemaker.remote_function.job._prepare_and_upload_runtime_scripts", return_value="some_s3_uri" +) +@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager") +@patch("sagemaker.remote_function.job.StoredFunction") +@patch("sagemaker.remote_function.job.Session", return_value=mock_session()) +def test_start_with_mpirun_single_node_with_nproc_per_node( + session, + mock_stored_function, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + secret_token, +): + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=ROLE_ARN, + include_local_workdir=True, + instance_type="ml.g5.12xlarge", + encrypt_inter_container_traffic=True, + use_torchrun=False, + use_mpirun=True, + nproc_per_node=2, + ) + + job = _Job.start(job_settings, job_function, func_args=(1, 2), func_kwargs={"c": 3, "d": 4}) + + assert job.job_name.startswith("job-function") + + mock_stored_function.assert_called_once_with( + sagemaker_session=session(), + s3_base_uri=f"{S3_URI}/{job.job_name}", + hmac_key=HMAC_KEY, + s3_kms_key=None, + ) + + mock_stored_function().save.assert_called_once_with(job_function, *(1, 2), **{"c": 3, "d": 4}) + + local_dependencies_path = mock_runtime_manager().snapshot() + mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() + + mock_script_upload.assert_called_once_with( + spark_config=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + use_torchrun=False, + use_mpirun=True, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path=local_dependencies_path, + include_local_workdir=True, + pre_execution_commands=None, + pre_execution_script_local_path=None, + s3_base_uri=f"{S3_URI}/{job.job_name}", + s3_kms_key=None, + sagemaker_session=session(), + custom_file_filter=None, + ) + + session().sagemaker_client.create_training_job.assert_called_once_with( + TrainingJobName=job.job_name, + RoleArn=ROLE_ARN, + StoppingCondition={"MaxRuntimeInSeconds": 86400}, + RetryStrategy={"MaximumRetryAttempts": 1}, + InputDataConfig=[ + dict( + ChannelName=RUNTIME_SCRIPTS_CHANNEL_NAME, + DataSource={ + "S3DataSource": { + "S3Uri": mock_script_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + dict( + ChannelName=REMOTE_FUNCTION_WORKSPACE, + DataSource={ + "S3DataSource": { + "S3Uri": mock_dependency_upload.return_value, + "S3DataType": "S3Prefix", + } + }, + ), + ], + OutputDataConfig={"S3OutputPath": f"{S3_URI}/{job.job_name}"}, + AlgorithmSpecification=dict( + TrainingImage=IMAGE, + TrainingInputMode="File", + ContainerEntrypoint=[ + "/bin/bash", + "/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh", + ], + ContainerArguments=[ + "--s3_base_uri", + f"{S3_URI}/{job.job_name}", + "--region", + TEST_REGION, + "--client_python_version", + mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, + "--dependency_settings", + '{"dependency_file": null}', + "--distribution", + "mpirun", + "--user_nproc_per_node", + "2", + "--run_in_context", + '{"experiment_name": "my-exp-name", "run_name": "my-run-name"}', + ], + ), 
+ ResourceConfig=dict( + VolumeSizeInGB=30, + InstanceCount=1, + InstanceType="ml.g5.12xlarge", + KeepAlivePeriodInSeconds=0, + ), + EnableNetworkIsolation=False, + EnableInterContainerTrafficEncryption=True, + EnableManagedSpotTraining=False, + Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY}, + ) + + +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_cpus", + return_value=48, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_gpus", + return_value=4, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.num_neurons", + return_value=0, +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment.safe_serialize", + side_effect=safe_serialize, +) +def test_set_env_single_node_multi_gpu_mpirun_with_nproc_per_node( + mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons +): + with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): + with tempfile.NamedTemporaryFile() as f: + set_env( + resource_config=dict( + current_host="algo-1", + hosts=["algo-1"], + current_group_name="homogeneousCluster", + current_instance_type="ml.g5.12xlarge", + instance_groups=[ + dict( + instance_group_name="homogeneousCluster", + instance_type="ml.g5.12xlarge", + hosts=["algo-1"], + ) + ], + network_interface_name="eth0", + ), + distribution="mpirun", + user_nproc_per_node=2, + output_file=f.name, + ) + + mock_num_cpus.assert_called_once() + mock_num_gpus.assert_called_once() + mock_num_neurons.assert_called_once() + + with open(f.name, "r") as f: + env_file = f.read().strip() + expected_env = _remove_extra_lines( + EXPECTED_ENV_SINGLE_NODE_MULTI_GPUS_MPIRUN_WITH_NPROC_PER_NODE + ) + env_file = _remove_extra_lines(env_file) + + assert env_file == expected_env + + +def _remove_extra_lines(string): + """Removes extra blank lines from a string.""" + return "\n".join([line for line in string.splitlines() if line.strip()]) diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 1c0cfa35b3..d149e08cab 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -56,6 +56,8 @@ def test_jumpstart_resource_requirements( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -76,6 +78,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.g5.xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -89,6 +92,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.g5.555xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -102,6 +106,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.f9.555xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -138,6 +143,8 @@ def 
test_jumpstart_no_supported_resource_requirements( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 16b7256ed2..6c4d6b688f 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -53,7 +53,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -71,7 +73,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -90,7 +94,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -109,7 +115,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -177,6 +185,6 @@ def test_jumpstart_artifact_bucket_override( model_version="*", ) assert ( - uri - == "s3://some-cool-bucket-name/source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" + uri == "s3://some-cool-bucket-name/source-directory-tarballs/pytorch/" + "transfer_learning/ic/prepack/v1.1.0/sourcedir.tar.gz" ) diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_pytorch.py b/tests/unit/sagemaker/script_uris/jumpstart/test_pytorch.py index d62db7a785..80f2449901 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_pytorch.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_pytorch.py @@ -29,22 +29,22 @@ def test_jumpstart_pytorch_script_uri(patched_get_model_specs): uri = script_uris.retrieve( region="us-west-2", script_scope="inference", - model_id="pytorch-eqa-bert-base-cased", + model_id="pytorch-ic-mobilenet-v2", model_version="*", ) assert ( - uri == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" - "inference/eqa/v1.0.0/sourcedir.tar.gz" + uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/inference/ic/v2.0.0/sourcedir.tar.gz" ) # training uri = script_uris.retrieve( region="us-west-2", script_scope="training", - model_id="pytorch-eqa-bert-base-cased", + model_id="pytorch-ic-mobilenet-v2", model_version="*", ) assert ( - uri == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" - "transfer_learning/eqa/v1.0.0/sourcedir.tar.gz" + uri == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/transfer_learning/ic/prepack/v1.1.0/sourcedir.tar.gz" ) diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index 90ec5df6b5..dde308dcfb 100644 --- 
a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -53,9 +53,11 @@ def test_jumpstart_default_serializers( patched_get_model_specs.assert_called_once_with( region=region, model_id=model_id, + hub_arn=None, version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -99,6 +101,8 @@ def test_jumpstart_serializer_options( region=region, model_id=model_id, version=model_version, + hub_arn=None, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/serve/builder/test_djl_builder.py b/tests/unit/sagemaker/serve/builder/test_djl_builder.py index ccabdb86b3..9c4488fa3e 100644 --- a/tests/unit/sagemaker/serve/builder/test_djl_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_djl_builder.py @@ -15,63 +15,39 @@ import unittest from sagemaker.serve.builder.model_builder import ModelBuilder -from sagemaker.serve.utils.types import _DjlEngine from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve import ModelServer -from sagemaker.djl_inference.model import ( - DeepSpeedModel, - FasterTransformerModel, - HuggingFaceAccelerateModel, -) +from sagemaker.djl_inference.model import DJLModel from sagemaker.serve.utils.exceptions import ( LocalDeepPingException, LocalModelLoadException, LocalModelOutOfMemoryException, LocalModelInvocationException, ) -from sagemaker.serve.utils.predictors import DjlLocalModePredictor +from sagemaker.serve.utils.predictors import DjlLocalModePredictor, InProcessModePredictor from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG mock_model_id = "TheBloke/Llama-2-7b-chat-fp16" -mock_t5_model_id = "google/flan-t5-xxl" mock_prompt = "Hello, I'm a language model," mock_response = "Hello, I'm a language model, and I'm here to help you with your English." 
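The rewritten fixtures below reflect the core change in this file: DJL serving is no longer configured through per-engine serving properties (DeepSpeed, FasterTransformer, HuggingFace Accelerate) but through environment variables on a single DJLModel. A minimal sketch of the flow these fixtures support, using only the ModelBuilder surface imported above (the values are illustrative, and in the tests themselves the config lookup is patched via _get_default_djl_configurations):

builder = ModelBuilder(
    model="TheBloke/Llama-2-7b-chat-fp16",   # any Hugging Face model id
    schema_builder=mock_schema_builder,      # sample input/output pair
    mode=Mode.LOCAL_CONTAINER,
    model_server=ModelServer.DJL_SERVING,
)
model = builder.build()  # now returns a DJLModel
# Tuning and engine settings live on model.env (e.g. TENSOR_PARALLEL_DEGREE,
# OPTION_DTYPE, MODEL_LOADING_TIMEOUT) instead of
# model.generate_serving_properties().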
mock_sample_input = {"inputs": mock_prompt, "parameters": {}} mock_sample_output = [{"generated_text": mock_response}] -mock_expected_huggingfaceaccelerate_serving_properties = { - "engine": "Python", - "option.entryPoint": "inference.py", - "option.model_id": "TheBloke/Llama-2-7b-chat-fp16", - "option.tensor_parallel_degree": 4, - "option.dtype": "fp16", -} -mock_expected_deepspeed_serving_properties = { - "engine": "DeepSpeed", - "option.entryPoint": "inference.py", - "option.model_id": "TheBloke/Llama-2-7b-chat-fp16", - "option.tensor_parallel_degree": 4, - "option.dtype": "fp16", - "option.max_tokens": 256, - "option.triangular_masking": True, - "option.return_tuple": True, -} -mock_expected_fastertransformer_serving_properties = { - "engine": "FasterTransformer", - "option.entryPoint": "inference.py", - "option.model_id": "google/flan-t5-xxl", - "option.tensor_parallel_degree": 4, - "option.dtype": "fp16", +mock_default_configs = { + "HF_MODEL_ID": mock_model_id, + "OPTION_ENGINE": "Python", + "TENSOR_PARALLEL_DEGREE": "max", + "OPTION_DTYPE": "bf16", + "MODEL_LOADING_TIMEOUT": "1800", } mock_most_performant_serving_properties = { - "engine": "Python", - "option.entryPoint": "inference.py", - "option.model_id": "TheBloke/Llama-2-7b-chat-fp16", - "option.tensor_parallel_degree": 1, - "option.dtype": "bf16", + "OPTION_ENGINE": "Python", + "HF_MODEL_ID": "TheBloke/Llama-2-7b-chat-fp16", + "TENSOR_PARALLEL_DEGREE": "1", + "OPTION_DTYPE": "bf16", + "MODEL_LOADING_TIMEOUT": "1800", } -mock_model_config_properties = {"model_type": "llama", "num_attention_heads": 32} -mock_model_config_properties_faster_transformer = {"model_type": "t5", "num_attention_heads": 32} -mock_set_serving_properties = (4, "fp16", 1, 256, 256) +mock_inference_spec = MagicMock() +mock_inference_spec.get_model = "TheBloke/Llama-2-7b-chat-fp16" mock_schema_builder = MagicMock() mock_schema_builder.sample_input = mock_sample_input @@ -88,29 +64,23 @@ class TestDjlBuilder(unittest.TestCase): "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") + @patch( + "sagemaker.serve.builder.djl_builder._get_default_djl_configurations", + return_value=(mock_default_configs, 128), + ) def test_build_deploy_for_djl_local_container( self, + mock_default_djl_config, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): builder = ModelBuilder( model=mock_model_id, + name="mock_model_name", schema_builder=mock_schema_builder, mode=Mode.LOCAL_CONTAINER, model_server=ModelServer.DJL_SERVING, @@ -122,22 +92,16 @@ def test_build_deploy_for_djl_local_container( builder._prepare_for_mode.side_effect = None model = builder.build() + assert model.name == "mock_model_name" + builder.serve_settings.telemetry_opt_out = True - assert isinstance(model, HuggingFaceAccelerateModel) - assert ( - model.generate_serving_properties() - == 
mock_expected_huggingfaceaccelerate_serving_properties - ) - assert builder._default_tensor_parallel_degree == 4 - assert builder._default_data_type == "fp16" - assert builder._default_max_tokens == 256 - assert builder._default_max_new_tokens == 256 - assert builder.schema_builder.sample_input["parameters"]["max_new_tokens"] == 256 + assert isinstance(model, DJLModel) + assert builder.schema_builder.sample_input["parameters"]["max_new_tokens"] == 128 assert builder.nb_instance_type == "ml.g5.24xlarge" assert model.image_config == MOCK_IMAGE_CONFIG assert model.vpc_config == MOCK_VPC_CONFIG - assert "deepspeed" in builder.image_uri + assert "lmi" in builder.image_uri builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() predictor = model.deploy(model_data_download_timeout=1800) @@ -153,100 +117,64 @@ def test_build_deploy_for_djl_local_container( with self.assertRaises(ValueError) as _: model.deploy(mode=Mode.IN_PROCESS) + @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=( - _DjlEngine.FASTER_TRANSFORMER, - mock_model_config_properties_faster_transformer, - ), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) + @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") - def test_build_for_djl_local_container_faster_transformer( - self, - mock_get_nb_instance, - mock_set_serving_properties, - mock_auto_detect_engine, - mock_is_jumpstart_model, - ): - builder = ModelBuilder( - model=mock_t5_model_id, - schema_builder=mock_schema_builder, - mode=Mode.LOCAL_CONTAINER, - model_server=ModelServer.DJL_SERVING, - image_config=MOCK_IMAGE_CONFIG, - vpc_config=MOCK_VPC_CONFIG, - ) - model = builder.build() - builder.serve_settings.telemetry_opt_out = True - - assert isinstance(model, FasterTransformerModel) - assert ( - model.generate_serving_properties() - == mock_expected_fastertransformer_serving_properties - ) - assert model.image_config == MOCK_IMAGE_CONFIG - assert model.vpc_config == MOCK_VPC_CONFIG - assert "fastertransformer" in builder.image_uri - - @patch( - "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", - return_value=False, - ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.DEEPSPEED, mock_model_config_properties), - ) @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, + "sagemaker.serve.builder.djl_builder._get_default_djl_configurations", + return_value=(mock_default_configs, 128), ) - @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") - def test_build_for_djl_local_container_deepspeed( + def test_build_deploy_for_djl_in_process( self, + mock_default_djl_config, mock_get_nb_instance, - mock_set_serving_properties, - mock_auto_detect_engine, + mock_get_ram_usage_mb, mock_is_jumpstart_model, + mock_telemetry, ): builder = ModelBuilder( model=mock_model_id, + name="mock_model_name", schema_builder=mock_schema_builder, - mode=Mode.LOCAL_CONTAINER, + mode=Mode.IN_PROCESS, model_server=ModelServer.DJL_SERVING, image_config=MOCK_IMAGE_CONFIG, vpc_config=MOCK_VPC_CONFIG, ) + + 
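# _prepare_for_mode is stubbed below so that build() skips real artifact
# preparation; the point of this test is only that IN_PROCESS mode still
# yields a DJLModel and that deploy() hands back an InProcessModePredictor.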
builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + assert model.name == "mock_model_name" + builder.serve_settings.telemetry_opt_out = True - assert isinstance(model, DeepSpeedModel) + assert isinstance(model, DJLModel) + assert builder.schema_builder.sample_input["parameters"]["max_new_tokens"] == 128 + assert builder.nb_instance_type == "ml.g5.24xlarge" assert model.image_config == MOCK_IMAGE_CONFIG assert model.vpc_config == MOCK_VPC_CONFIG - assert model.generate_serving_properties() == mock_expected_deepspeed_serving_properties - assert "deepspeed" in builder.image_uri + assert "lmi" in builder.image_uri + + builder.modes[str(Mode.IN_PROCESS)] = MagicMock() + predictor = model.deploy(model_data_download_timeout=1800) + + assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" + assert isinstance(predictor, InProcessModePredictor) + + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") @patch( @@ -261,16 +189,18 @@ def test_build_for_djl_local_container_deepspeed( "sagemaker.serve.builder.djl_builder._concurrent_benchmark", side_effect=[(0.03, 16), (0.10, 4), (0.15, 2)], ) + @patch( + "sagemaker.serve.builder.djl_builder._get_default_djl_configurations", + return_value=(mock_default_configs, 128), + ) def test_tune_for_djl_local_container( self, + mock_default_djl_config, mock_concurrent_benchmarks, mock_serial_benchmarks, mock_admissible_tensor_parallel_degrees, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): @@ -287,41 +217,31 @@ def test_tune_for_djl_local_container( model = builder.build() builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() - assert tuned_model.generate_serving_properties() == mock_most_performant_serving_properties + assert tuned_model.env == mock_most_performant_serving_properties @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - 
**{"return_value.raiseError.side_effect": LocalDeepPingException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalDeepPingException("mock_exception")}, ) @patch( "sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees", return_value=[4], ) + @patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None) def test_tune_for_djl_local_container_deep_ping_ex( self, + mock_get_available_gpus, mock_get_admissible_tensor_parallel_degrees, mock_serial_benchmarks, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): @@ -337,47 +257,38 @@ def test_tune_for_djl_local_container_deep_ping_ex( model = builder.build() builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() - assert ( - tuned_model.generate_serving_properties() - == mock_expected_huggingfaceaccelerate_serving_properties - ) + assert tuned_model.env == mock_default_configs + @patch("sagemaker.serve.builder.djl_builder._get_model_config_properties_from_hf") @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelLoadException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelLoadException("mock_exception")}, ) @patch( "sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees", return_value=[4], ) + @patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None) def test_tune_for_djl_local_container_load_ex( self, + mock_get_available_gpus, mock_get_admissible_tensor_parallel_degrees, mock_serial_benchmarks, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, + mock_get_model_config_properties_from_hf, ): + mock_get_model_config_properties_from_hf.return_value = {} + builder = ModelBuilder( model=mock_model_id, schema_builder=mock_schema_builder, @@ -390,44 +301,31 @@ def test_tune_for_djl_local_container_load_ex( model = builder.build() builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() - assert ( - tuned_model.generate_serving_properties() - == mock_expected_huggingfaceaccelerate_serving_properties - ) + assert tuned_model.env == mock_default_configs @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - 
"sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelOutOfMemoryException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelOutOfMemoryException("mock_exception")}, ) @patch( "sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees", return_value=[4], ) + @patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None) def test_tune_for_djl_local_container_oom_ex( self, + mock_get_available_gpus, mock_get_admissible_tensor_parallel_degrees, mock_serial_benchmarks, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): @@ -443,44 +341,31 @@ def test_tune_for_djl_local_container_oom_ex( model = builder.build() builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() - assert ( - tuned_model.generate_serving_properties() - == mock_expected_huggingfaceaccelerate_serving_properties - ) + assert tuned_model.env == mock_default_configs @patch("sagemaker.serve.builder.djl_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", return_value=False, ) - @patch( - "sagemaker.serve.builder.djl_builder._auto_detect_engine", - return_value=(_DjlEngine.HUGGINGFACE_ACCELERATE, mock_model_config_properties), - ) - @patch( - "sagemaker.serve.builder.djl_builder._set_serve_properties", - return_value=mock_set_serving_properties, - ) - @patch("sagemaker.serve.builder.djl_builder.prepare_for_djl_serving", side_effect=None) @patch("sagemaker.serve.builder.djl_builder._get_ram_usage_mb", return_value=1024) @patch("sagemaker.serve.builder.djl_builder._get_nb_instance", return_value="ml.g5.24xlarge") @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")}, ) @patch( "sagemaker.serve.builder.djl_builder._get_admissible_tensor_parallel_degrees", return_value=[4], ) + @patch("sagemaker.serve.model_server.djl_serving.utils._get_available_gpus", return_value=None) def test_tune_for_djl_local_container_invoke_ex( self, + mock_get_available_gpus, mock_get_admissible_tensor_parallel_degrees, mock_serial_benchmarks, mock_get_nb_instance, mock_get_ram_usage_mb, - mock_prepare_for_djl_serving, - mock_set_serving_properties, - mock_auto_detect_engine, mock_is_jumpstart_model, mock_telemetry, ): @@ -496,10 +381,7 @@ def test_tune_for_djl_local_container_invoke_ex( model = builder.build() builder.serve_settings.telemetry_opt_out = True tuned_model = model.tune() - assert ( - tuned_model.generate_serving_properties() - == mock_expected_huggingfaceaccelerate_serving_properties - ) + assert tuned_model.env == mock_default_configs @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py 
b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 2a0c791215..415d7eab5b 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -11,10 +11,12 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, Mock import unittest +from sagemaker.enums import Tag +from sagemaker.serve import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode from sagemaker.serve.utils.exceptions import ( @@ -23,6 +25,11 @@ LocalModelOutOfMemoryException, LocalModelInvocationException, ) +from tests.unit.sagemaker.serve.constants import ( + DEPLOYMENT_CONFIGS, + OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL, + CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES, +) mock_model_id = "huggingface-llm-amazon-falconlite" mock_t5_model_id = "google/flan-t5-xxl" @@ -63,8 +70,12 @@ "123456789712.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi" "-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" ) +mock_invalid_image_uri = ( + "123456789712.dkr.ecr.us-west-2.amazonaws.com/invalid" + "-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" +) mock_djl_image_uri = ( - "123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1" + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124" ) mock_model_data = { @@ -80,8 +91,158 @@ "/artifacts/inference-prepack/v1.0.0/" ) +mock_optimization_job_response = { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:optimization-job" + "/modelbuilderjob-c9b28846f963497ca540010b2aa2ec8d", + "OptimizationJobStatus": "COMPLETED", + "OptimizationStartTime": "", + "OptimizationEndTime": "", + "CreationTime": "", + "LastModifiedTime": "", + "OptimizationJobName": "modelbuilderjob-c9b28846f963497ca540010b2aa2ec8d", + "ModelSource": { + "S3": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b-instruct/artifacts/inference-prepack/v1.1.0/" + } + }, + "OptimizationEnvironment": { + "ENDPOINT_SERVER_TIMEOUT": "3600", + "HF_MODEL_ID": "/opt/ml/model", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "OPTION_DTYPE": "fp16", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "SAGEMAKER_PROGRAM": "inference.py", + }, + "DeploymentInstanceType": "ml.inf2.48xlarge", + "OptimizationConfigs": [ + { + "ModelCompilationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-neuronx-sdk2.18.2", + "OverrideEnvironment": { + "OPTION_DTYPE": "fp16", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, + } + } + ], + "OutputConfig": { + "S3OutputLocation": "s3://dont-delete-ss-jarvis-integ-test-312206380606-us-west-2/" + "code/a75a061aba764f2aa014042bcdc1464b/" + }, + "OptimizationOutput": { + "RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "djl-inference:0.28.0-neuronx-sdk2.18.2" + }, + "RoleArn": 
"arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20230707T131628", + "StoppingCondition": {"MaxRuntimeInSeconds": 36000}, + "ResponseMetadata": { + "RequestId": "704c7bcd-41e2-4d73-8039-262ff6a3f38b", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "704c7bcd-41e2-4d73-8039-262ff6a3f38b", + "content-type": "application/x-amz-json-1.1", + "content-length": "1787", + "date": "Thu, 04 Jul 2024 16:55:50 GMT", + }, + "RetryAttempts": 0, + }, +} + class TestJumpStartBuilder(unittest.TestCase): + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test__build_for_jumpstart_value_error( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/invalid", + schema_builder=mock_schema_builder, + mode=Mode.LOCAL_CONTAINER, + ) + + mock_pre_trained_model.return_value.image_uri = mock_invalid_image_uri + + self.assertRaises( + ValueError, + lambda: builder.build(), + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_mms_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test__build_for_mms_jumpstart( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_mms, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + mode=Mode.LOCAL_CONTAINER, + ) + + mock_pre_trained_model.return_value.image_uri = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface" + "-pytorch-inference:2.1.0-transformers4.37.0-gpu-py310-cu118" + "-ubuntu20.04" + ) + + builder.build() + builder.serve_settings.telemetry_opt_out = True + + mock_prepare_for_mms.assert_called() + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch( "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", @@ -217,7 +378,7 @@ def test_tune_for_tgi_js_local_container_sharding_not_supported( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalDeepPingException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalDeepPingException("mock_exception")}, ) def 
test_tune_for_tgi_js_local_container_deep_ping_ex( self, @@ -267,7 +428,7 @@ def test_tune_for_tgi_js_local_container_deep_ping_ex( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelLoadException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelLoadException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_load_ex( self, @@ -317,7 +478,7 @@ def test_tune_for_tgi_js_local_container_load_ex( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelOutOfMemoryException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelOutOfMemoryException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_oom_ex( self, @@ -367,7 +528,7 @@ def test_tune_for_tgi_js_local_container_oom_ex( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")}, ) def test_tune_for_tgi_js_local_container_invoke_ex( self, @@ -482,7 +643,7 @@ def test_tune_for_djl_js_local_container( ) @patch( "sagemaker.serve.builder.djl_builder._serial_benchmark", - **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")} + **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")}, ) def test_tune_for_djl_js_local_container_invoke_ex( self, @@ -638,3 +799,1161 @@ def test_js_gated_model_ex( ValueError, lambda: builder.build(), ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_list_deployment_configs( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + configs = builder.list_deployment_configs() + + self.assertEqual(configs, DEPLOYMENT_CONFIGS) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + 
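# The deployment-config tests in this block (list_deployment_configs,
# get_deployment_config, set_deployment_config, display_benchmark_metrics)
# all delegate to the underlying pre-trained JumpStartModel mock; only
# set_deployment_config requires builder.build() first, since it raises
# "Cannot set deployment config to an uninitialized model." otherwise.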
"sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_get_deployment_config( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + expected = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value.deployment_config = expected + + self.assertEqual(builder.get_deployment_config(), expected) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_set_deployment_config( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + builder.build() + builder.set_deployment_config("config-1", "ml.g5.24xlarge") + + mock_pre_trained_model.return_value.set_deployment_config.assert_called_with( + "config-1", "ml.g5.24xlarge" + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_set_deployment_config_ex( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + self.assertRaisesRegex( + Exception, + "Cannot set deployment config to an uninitialized model.", + lambda: ModelBuilder( + model="facebook/galactica-mock-model-id", schema_builder=mock_schema_builder + ).set_deployment_config("config-2", "ml.g5.24xlarge"), + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", 
"n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_display_benchmark_metrics( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + builder.list_deployment_configs() + + builder.display_benchmark_metrics() + + mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once() + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + def test_display_benchmark_metrics_initial( + self, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.side_effect = ( + lambda: DEPLOYMENT_CONFIGS + ) + + builder.display_benchmark_metrics() + + mock_pre_trained_model.return_value.display_benchmark_metrics.assert_called_once() + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_fine_tuned_model_with_fine_tuning_model_path( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri + mock_fine_tuning_model_path = "s3://test" + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species of turtle native to the brackish " + "coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a species of turtle native to the " + "brackish coastal tidal marshes of the east coast." 
+ } + ] + builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + model_metadata={ + "FINE_TUNING_MODEL_PATH": mock_fine_tuning_model_path, + }, + ) + model = builder.build() + + model.model_data["S3DataSource"].__setitem__.assert_called_with( + "S3Uri", mock_fine_tuning_model_path + ) + mock_pre_trained_model.return_value.add_tags.assert_called_with( + {"Key": Tag.FINE_TUNING_MODEL_PATH, "Value": mock_fine_tuning_model_path} + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_fine_tuned_model_with_fine_tuning_job_name( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_serve_settings, + mock_telemetry, + ): + mock_fine_tuning_model_path = "s3://test" + mock_sagemaker_session = Mock() + mock_sagemaker_session.sagemaker_client.describe_training_job.return_value = { + "ModelArtifacts": { + "S3ModelArtifacts": mock_fine_tuning_model_path, + } + } + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri + mock_fine_tuning_job_name = "mock-job" + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species of turtle native to the brackish " + "coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a species of turtle native to the " + "brackish coastal tidal marshes of the east coast." 
+ } + ] + builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + model_metadata={"FINE_TUNING_JOB_NAME": mock_fine_tuning_job_name}, + sagemaker_session=mock_sagemaker_session, + ) + model = builder.build(sagemaker_session=mock_sagemaker_session) + + mock_sagemaker_session.sagemaker_client.describe_training_job.assert_called_once_with( + TrainingJobName=mock_fine_tuning_job_name + ) + + model.model_data["S3DataSource"].__setitem__.assert_any_call( + "S3Uri", mock_fine_tuning_model_path + ) + mock_pre_trained_model.return_value.add_tags.assert_called_with( + [ + {"key": Tag.FINE_TUNING_JOB_NAME, "value": mock_fine_tuning_job_name}, + {"key": Tag.FINE_TUNING_MODEL_PATH, "value": mock_fine_tuning_model_path}, + ] + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_quantize_for_jumpstart( + self, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + + mock_pysdk_model = Mock() + mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"} + mock_pysdk_model.model_data = mock_model_data + mock_pysdk_model.image_uri = mock_tgi_image_uri + mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS + mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + model_builder.pysdk_model = mock_pysdk_model + + out_put = model_builder._optimize_for_jumpstart( + accept_eula=True, + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + env_vars={ + "OPTION_TENSOR_PARALLEL_DEGREE": "1", + "OPTION_MAX_ROLLING_BATCH_SIZE": "2", + }, + output_path="s3://bucket/code/", + ) + + self.assertIsNotNone(out_put) + self.assertEqual( + out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._jumpstart_speculative_decoding", + return_value=True, + ) + def test_jumpstart_model_provider_calls_jumpstart_speculative_decoding( + self, + mock_js_speculative_decoding, + mock_pretrained_js_model, + mock_is_js_model, + mock_serve_settings, + mock_capture_telemetry, + ): + mock_sagemaker_session = Mock() + mock_pysdk_model = Mock() + mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"} + mock_pysdk_model.model_data = mock_model_data + mock_pysdk_model.image_uri = mock_tgi_image_uri + mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS + mock_pysdk_model.deployment_config = OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL + mock_pysdk_model.additional_model_data_sources = CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + model_builder.pysdk_model = mock_pysdk_model + + model_builder._optimize_for_jumpstart( + accept_eula=True, + speculative_decoding_config={ + "ModelProvider": "JumpStart", + "ModelID": "meta-textgeneration-llama-3-2-1b", + "AcceptEula": False, + }, + ) + + mock_js_speculative_decoding.assert_called_once() + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_quantize_and_compile_for_jumpstart( + self, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_metadata_config = Mock() + mock_metadata_config.resolved_config = { + "supported_inference_instance_types": ["ml.inf2.48xlarge"], + "hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b", + } + + mock_pysdk_model = Mock() + mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"} + mock_pysdk_model.model_data = mock_model_data + mock_pysdk_model.image_uri = mock_tgi_image_uri + mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS + mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pysdk_model.config_name = "config_name" + mock_pysdk_model._metadata_configs = {"config_name": mock_metadata_config} + mock_pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + model_builder.pysdk_model = mock_pysdk_model + + out_put = model_builder._optimize_for_jumpstart( + accept_eula=True, + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + output_path="s3://bucket/code/", + ) + + self.assertIsNotNone(out_put) + self.assertIsNone(out_put["OptimizationConfigs"][1]["ModelCompilationConfig"].get("Image")) + self.assertIsNone(out_put["OptimizationConfigs"][0]["ModelQuantizationConfig"].get("Image")) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model") + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_optimize_compile_for_jumpstart_without_neuron_env( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_pre_trained_model.return_value = MagicMock() + mock_pre_trained_model.return_value.env = dict() + mock_pre_trained_model.return_value.model_data = mock_model_data + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.return_value = ( + DEPLOYMENT_CONFIGS + ) + mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value._metadata_configs = None + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.inf2.48xlarge", + compilation_config={ + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + }, + output_path="s3://bucket/code/", + ) + + self.assertEqual( + optimized_model.image_uri, + mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"], + ) + self.assertEqual( + optimized_model.model_data["S3DataSource"]["S3Uri"], + mock_optimization_job_response["OutputConfig"]["S3OutputLocation"], + ) + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model") + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_optimize_compile_for_jumpstart_with_neuron_env( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_metadata_config = Mock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_metadata_config.resolved_config = { + "supported_inference_instance_types": ["ml.inf2.48xlarge"], + "hosting_neuron_model_id": "neuron_model_id", + } + + mock_js_model.return_value = MagicMock() + mock_js_model.return_value.env = dict() + + mock_pre_trained_model.return_value = MagicMock() + mock_pre_trained_model.return_value.env = dict() + mock_pre_trained_model.return_value.config_name = "config_name" + mock_pre_trained_model.return_value.model_data = mock_model_data + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.return_value = ( + DEPLOYMENT_CONFIGS + ) + mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value._metadata_configs = { + "config_name": mock_metadata_config + } + + sample_input = { + "inputs": "The diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." 
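+                # note: once a Neuron deployment config resolves, the OPTION_* overrides passed below should surface in model.env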
+ } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.inf2.48xlarge", + compilation_config={ + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "OPTION_N_POSITIONS": "2048", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + }, + output_path="s3://bucket/code/", + ) + + self.assertEqual( + optimized_model.image_uri, + mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"], + ) + self.assertEqual( + optimized_model.model_data["S3DataSource"]["S3Uri"], + mock_optimization_job_response["OutputConfig"]["S3OutputLocation"], + ) + self.assertEqual(optimized_model.env["OPTION_TENSOR_PARALLEL_DEGREE"], "2") + self.assertEqual(optimized_model.env["OPTION_N_POSITIONS"], "2048") + self.assertEqual(optimized_model.env["OPTION_DTYPE"], "fp16") + self.assertEqual(optimized_model.env["OPTION_ROLLING_BATCH"], "auto") + self.assertEqual(optimized_model.env["OPTION_MAX_ROLLING_BATCH_SIZE"], "4") + self.assertEqual(optimized_model.env["OPTION_NEURON_OPTIMIZE_LEVEL"], "2") + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model") + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_optimize_compile_for_jumpstart_without_compilation_config( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = Mock() + mock_metadata_config = Mock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_metadata_config.resolved_config = { + "supported_inference_instance_types": ["ml.inf2.48xlarge"], + "hosting_neuron_model_id": "huggingface-llmneuron-mistral-7b", + } + + mock_js_model.return_value = MagicMock() + mock_js_model.return_value.env = { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + } + + mock_pre_trained_model.return_value = MagicMock() + mock_pre_trained_model.return_value.env = dict() + mock_pre_trained_model.return_value.config_name = "config_name" + mock_pre_trained_model.return_value.model_data = mock_model_data + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + mock_pre_trained_model.return_value.list_deployment_configs.return_value = ( + DEPLOYMENT_CONFIGS + ) + mock_pre_trained_model.return_value.deployment_config = DEPLOYMENT_CONFIGS[0] + mock_pre_trained_model.return_value._metadata_configs = { + "config_name": mock_metadata_config + } + + sample_input = { + "inputs": "The 
diamondback terrapin or simply terrapin is a species " + "of turtle native to the brackish coastal tidal marshes of the", + "parameters": {"max_new_tokens": 1024}, + } + sample_output = [ + { + "generated_text": "The diamondback terrapin or simply terrapin is a " + "species of turtle native to the brackish coastal " + "tidal marshes of the east coast." + } + ] + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + schema_builder=SchemaBuilder(sample_input, sample_output), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.inf2.24xlarge", + output_path="s3://bucket/code/", + ) + + self.assertEqual( + optimized_model.image_uri, + mock_optimization_job_response["OptimizationOutput"]["RecommendedInferenceImage"], + ) + self.assertEqual( + optimized_model.model_data["S3DataSource"]["S3Uri"], + mock_optimization_job_response["OutputConfig"]["S3OutputLocation"], + ) + self.assertEqual(optimized_model.env["SAGEMAKER_PROGRAM"], "inference.py") + self.assertEqual(optimized_model.env["ENDPOINT_SERVER_TIMEOUT"], "3600") + self.assertEqual(optimized_model.env["MODEL_CACHE_ROOT"], "/opt/ml/model") + self.assertEqual(optimized_model.env["SAGEMAKER_ENV"], "1") + self.assertEqual(optimized_model.env["HF_MODEL_ID"], "/opt/ml/model") + self.assertEqual(optimized_model.env["SAGEMAKER_MODEL_SERVER_WORKERS"], "1") + + +class TestJumpStartModelBuilderOptimizationUseCases(unittest.TestCase): + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model", + return_value=False, + ) + def test_optimize_on_js_model_should_ignore_pre_optimized_configurations( + self, + mock_is_fine_tuned, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = MagicMock() + mock_sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_lmi_js_model = MagicMock() + mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + mock_lmi_js_model.env = { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "OPTION_ENFORCE_EAGER": "true", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "OPTION_TENSOR_PARALLEL_DEGREE": "8", + } + + mock_js_model.return_value = mock_lmi_js_model + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-1-70b-instruct", + schema_builder=SchemaBuilder("test", "test"), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": { + "OPTION_QUANTIZE": "fp8", + "OPTION_TENSOR_PARALLEL_DEGREE": "4", + }, + }, + 
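+            # note: these overrides are expected to supersede the pre-optimized env (tensor parallel degree 8 -> 4, per the assertions below)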
output_path="s3://bucket/code/", + ) + + assert ( + mock_sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124" + ) + + assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == { + "instance_type": "ml.g5.24xlarge", + "config_name": "lmi", + } + assert optimized_model.env == { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "OPTION_ENFORCE_EAGER": "true", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "OPTION_TENSOR_PARALLEL_DEGREE": "4", # should be overridden from 8 to 4 + "OPTION_QUANTIZE": "fp8", # should be added to the env + } + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model", + return_value=False, + ) + def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_override( + self, + mock_is_fine_tuned, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + mock_sagemaker_session = MagicMock() + mock_sagemaker_session.sagemaker_client.create_optimization_job = MagicMock() + mock_sagemaker_session.wait_for_optimization_job.side_effect = ( + lambda *args: mock_optimization_job_response + ) + + mock_lmi_js_model = MagicMock() + mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi27.0.0-cu124" + } + mock_lmi_js_model.env = { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "OPTION_ENFORCE_EAGER": "true", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "OPTION_TENSOR_PARALLEL_DEGREE": "8", + } + + mock_js_model.return_value = mock_lmi_js_model + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-1-70b-instruct", + schema_builder=SchemaBuilder("test", "test"), + sagemaker_session=mock_sagemaker_session, + ) + + optimized_model = model_builder.optimize( + accept_eula=True, + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": { + "OPTION_QUANTIZE": "fp8", + }, + }, + output_path="s3://bucket/code/", + ) + + assert ( + mock_sagemaker_session.sagemaker_client.create_optimization_job.call_args_list[0][1][ + "OptimizationConfigs" + ][0]["ModelQuantizationConfig"]["Image"] + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi27.0.0-cu124" + ) + + assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == { + "instance_type": "ml.g5.24xlarge", + "config_name": "lmi", + } + assert optimized_model.env == { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "OPTION_ENFORCE_EAGER": "true", + 
"SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "OPTION_TENSOR_PARALLEL_DEGREE": "8", + "OPTION_QUANTIZE": "fp8", # should be added to the env + } + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", + return_value=True, + ) + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStartModel") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_fine_tuned_model", + return_value=False, + ) + def test_optimize_on_js_model_test_image_defaulting_scenarios( + self, + mock_is_fine_tuned, + mock_is_jumpstart_model, + mock_js_model, + mock_is_gated_model, + mock_serve_settings, + mock_telemetry, + ): + + mock_lmi_js_model = MagicMock() + mock_lmi_js_model.image_uri = mock_djl_image_uri + mock_lmi_js_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-1-70b-instruct", + schema_builder=SchemaBuilder("test", "test"), + sagemaker_session=MagicMock(), + ) + model_builder.pysdk_model = mock_lmi_js_model + + # assert lmi version is upgraded to hardcoded default + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + ) + + # assert lmi version is left as is + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi21.0.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi21.0.0-cu124", + ) + + # assert lmi version is upgraded to the highest provided version + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelShardingConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + }, + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124" + } + }, + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelShardingConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + self.assertEqual( + optimization_args["OptimizationConfigs"][1]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + + # assert lmi version is upgraded to the highest provided version and sets empty image config + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124" + } + }, + {"ModelShardingConfig": 
{}}, + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + self.assertEqual( + optimization_args["OptimizationConfigs"][1]["ModelShardingConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi30.0.0-cu124", + ) + + # assert lmi version is left as is on minor version bump + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.1.0-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.1.0-cu124", + ) + + # assert lmi version is left as is on patch version bump + optimization_args = model_builder._set_optimization_image_default( + { + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124" + } + } + ] + } + ) + + self.assertEqual( + optimization_args["OptimizationConfigs"][0]["ModelQuantizationConfig"]["Image"], + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi13.0.1-cu124", + ) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 1d199b7401..8ae6072ee5 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -11,17 +11,35 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import -from unittest.mock import MagicMock, patch, Mock, mock_open + +from unittest.mock import MagicMock, patch, Mock, mock_open, ANY import unittest from pathlib import Path from copy import deepcopy +import deepdiff +import pytest +from sagemaker.enums import EndpointType + +from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig +from sagemaker.batch_inference.batch_transform_inference_config import BatchTransformInferenceConfig + +from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements + +from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig + +from sagemaker.model import Model + +from sagemaker.serve import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.serve.model_format.mlflow.constants import MLFLOW_TRACKING_ARN from sagemaker.serve.utils import task from sagemaker.serve.utils.exceptions import TaskNotFoundException +from sagemaker.serve.utils.predictors import TensorflowServingLocalPredictor from sagemaker.serve.utils.types import ModelServer +from sagemaker.serve.validations.optimization import _validate_optimization_configuration from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG schema_builder = MagicMock() @@ -43,39 +61,41 @@ mock_image_uri = "abcd/efghijk" mock_1p_dlc_image_uri = "763104351884.dkr.ecr.us-east-1.amazonaws.com" -mock_role_arn = "sample role arn" +mock_role_arn = "arn:aws:iam::123456789012:role/SageMakerRole" mock_s3_model_data_url = "sample s3 data url" mock_secret_key = "mock_secret_key" mock_instance_type = "mock instance type" -supported_model_server = { +supported_model_servers = { ModelServer.TORCHSERVE, ModelServer.TRITON, ModelServer.DJL_SERVING, + ModelServer.TENSORFLOW_SERVING, + ModelServer.MMS, + ModelServer.TGI, + ModelServer.TEI, + ModelServer.SMD, } mock_session = MagicMock() +RESOURCE_REQUIREMENTS = ResourceRequirements( + requests={ + "num_cpus": 0.5, + "memory": 512, + "copies": 2, + }, + limits={}, +) -class TestModelBuilder(unittest.TestCase): - @patch("sagemaker.serve.builder.model_builder._ServeSettings") - def test_validation_in_progress_mode_not_supported(self, mock_serveSettings): - builder = ModelBuilder() - self.assertRaisesRegex( - Exception, - "IN_PROCESS mode is not supported yet!", - builder.build, - Mode.IN_PROCESS, - mock_role_arn, - mock_session, - ) +class TestModelBuilder(unittest.TestCase): @patch("sagemaker.serve.builder.model_builder._ServeSettings") def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSettings): builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object)) self.assertRaisesRegex( Exception, - "Cannot have both the Model and Inference spec in the builder", + "Can only set one of the following: model, inference_spec.", builder.build, Mode.SAGEMAKER_ENDPOINT, mock_role_arn, @@ -88,7 +108,7 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings): self.assertRaisesRegex( Exception, "%s is not supported yet! 
Supported model servers: %s" - % (builder.model_server, supported_model_server), + % (builder.model_server, supported_model_servers), builder.build, Mode.SAGEMAKER_ENDPOINT, mock_role_arn, @@ -101,7 +121,7 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings self.assertRaisesRegex( Exception, "Model_server must be set when non-first-party image_uri is set. " - + "Supported model servers: %s" % supported_model_server, + + "Supported model servers: %s" % supported_model_servers, builder.build, Mode.SAGEMAKER_ENDPOINT, mock_role_arn, @@ -122,7 +142,125 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set mock_session, ) + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl") + def test_model_server_override_djl_with_model(self, mock_build_for_djl, mock_serve_settings): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.DJL_SERVING, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_djl.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_settings): + builder = ModelBuilder( + model_server=ModelServer.DJL_SERVING, model=None, inference_spec=None + ) + self.assertRaisesRegex( + Exception, + "Missing required parameter `model` or 'ml_flow' path", + builder.build, + Mode.SAGEMAKER_ENDPOINT, + mock_role_arn, + mock_session, + ) + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve") + def test_model_server_override_torchserve_with_model( + self, mock_build_for_ts, mock_serve_settings + ): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings): + builder = ModelBuilder(model_server=ModelServer.TORCHSERVE) + self.assertRaisesRegex( + Exception, + "Missing required parameter `model` or 'ml_flow' path", + builder.build, + Mode.SAGEMAKER_ENDPOINT, + mock_role_arn, + mock_session, + ) + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton") + def test_model_server_override_triton_with_model(self, mock_build_for_ts, mock_serve_settings): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.TRITON, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving") + def test_model_server_override_tensor_with_model(self, mock_build_for_ts, 
mock_serve_settings): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.TENSORFLOW_SERVING, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei") + def test_model_server_override_tei_with_model(self, mock_build_for_ts, mock_serve_settings): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.TEI, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi") + def test_model_server_override_tgi_with_model(self, mock_build_for_ts, mock_serve_settings): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.TGI, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") + def test_model_server_override_transformers_with_model( + self, mock_build_for_ts, mock_serve_settings + ): + mock_setting_object = mock_serve_settings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + builder = ModelBuilder(model_server=ModelServer.MMS, model="gpt_llm_burt") + builder.build(sagemaker_session=mock_session) + + mock_build_for_ts.assert_called_once() + @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -141,6 +279,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -177,7 +316,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -191,7 +330,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( ) mock_model_obj = Mock() - mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, 
sagemaker_session, predictor_cls: ( # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501 mock_model_obj if image_uri == mock_image_uri and image_config == MOCK_IMAGE_CONFIG @@ -200,6 +339,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( and role == mock_role_arn and env == ENV_VARS and sagemaker_session == mock_session + and "model-name-" in name else None ) @@ -227,6 +367,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc( self.assertEqual(build_result.serve_settings, mock_setting_object) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -245,6 +389,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -280,7 +425,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -294,13 +439,14 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc( ) mock_model_obj = Mock() - mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501 mock_model_obj if image_uri == mock_1p_dlc_image_uri and model_data == model_data and role == mock_role_arn and env == ENV_VARS and sagemaker_session == mock_session + and "model-name-" in name else None ) @@ -387,7 +533,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( if inference_spec == mock_inference_spec and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -401,13 +547,14 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( ) mock_model_obj = Mock() - mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501 mock_model_obj if image_uri == mock_image_uri and model_data == model_data and role == mock_role_arn and env == ENV_VARS_INF_SPEC and 
sagemaker_session == mock_session + and "model-name-" in name else None ) @@ -429,6 +576,10 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec( self.assertEqual(build_result.serve_settings, mock_setting_object) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -447,6 +598,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -483,7 +635,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -497,13 +649,14 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( ) mock_model_obj = Mock() - mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501 mock_model_obj if image_uri == mock_image_uri and model_data == model_data and role == mock_role_arn and env == ENV_VARS and sagemaker_session == mock_session + and "model-name-" in name else None ) @@ -531,6 +684,10 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model( self.assertEqual("sample agent ModelBuilder", user_agent) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder.save_xgboost") @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @@ -551,6 +708,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( mock_prepare_for_torchserve, mock_detect_fw_version, mock_save_xgb, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_container.side_effect = lambda model, region, instance_type: ( @@ -587,7 +745,7 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -601,13 +759,14 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model( ) mock_model_obj = Mock() - 
mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501 mock_model_obj if image_uri == mock_image_uri and model_data == model_data and role == mock_role_arn and env == ENV_VARS and sagemaker_session == mock_session + and "model-name-" in name else None ) @@ -706,13 +865,14 @@ def test_build_happy_path_with_local_container_mode( mock_mode.prepare.side_effect = lambda: None mock_model_obj = Mock() - mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501 mock_model_obj if image_uri == mock_image_uri and model_data is None and role == mock_role_arn and env == {} and sagemaker_session == mock_session + and "model-name-" in name else None ) @@ -813,7 +973,7 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo if inference_spec == mock_inference_spec and model_server == ModelServer.TORCHSERVE else None ) - mock_sagemaker_endpoint_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_sagemaker_endpoint_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -827,13 +987,14 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo ) mock_model_obj = Mock() - mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501 mock_model_obj if image_uri == mock_image_uri and model_data is None and role == mock_role_arn and env == {} and sagemaker_session == mock_session + and "model-name-" in name else None ) @@ -879,6 +1040,10 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo ) @patch("os.makedirs", Mock()) + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id", + return_value=False, + ) @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -899,6 +1064,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_is_jumpstart_model_id, ): # setup mocks mock_detect_fw_version.return_value = framework, version @@ -935,7 +1101,7 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None ) - mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, jumpstart: ( # noqa E501 + mock_mode.prepare.side_effect = lambda model_path, secret_key, s3_model_data_url, sagemaker_session, image_uri, 
jumpstart, **kwargs: ( # noqa E501 ( model_data, ENV_VAR_PAIR, @@ -973,13 +1139,14 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co ) mock_model_obj = Mock() - mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls: ( # noqa E501 + mock_sdk_model.side_effect = lambda image_uri, image_config, vpc_config, model_data, role, env, sagemaker_session, predictor_cls, name: ( # noqa E501 mock_model_obj if image_uri == mock_image_uri and model_data == model_data and role == mock_role_arn and env == ENV_VARS and sagemaker_session == mock_session + and "model-name-" in name else None ) @@ -1006,8 +1173,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1052,8 +1219,8 @@ def test_build_happy_path_when_schema_builder_not_present( @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1098,8 +1265,8 @@ def test_build_negative_path_when_schema_builder_not_present( @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1136,51 +1303,11 @@ def test_build_can_fit_on_single_gpu( mock_can_fit_on_single_gpu.assert_called_once() - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl") - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") - @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") - @patch("sagemaker.huggingface.llm_utils.urllib") - @patch("sagemaker.huggingface.llm_utils.json") - @patch("sagemaker.model_uris.retrieve") - @patch("sagemaker.serve.builder.model_builder._ServeSettings") - def test_build_is_deepspeed_model( - self, - mock_serveSettings, - mock_model_uris_retrieve, - mock_llm_utils_json, - mock_llm_utils_urllib, - mock_model_json, - mock_model_urllib, - mock_image_uris_retrieve, - mock_can_fit_on_single_gpu, - mock_build_for_djl, - ): - mock_setting_object = mock_serveSettings.return_value - mock_setting_object.role_arn = mock_role_arn - mock_setting_object.s3_model_data_url = mock_s3_model_data_url - - mock_model_uris_retrieve.side_effect = KeyError - mock_llm_utils_json.load.return_value = {"pipeline_tag": 
"text-classification"} - mock_llm_utils_urllib.request.Request.side_effect = Mock() - - mock_model_json.load.return_value = {"some": "config"} - mock_model_urllib.request.Request.side_effect = Mock() - - mock_image_uris_retrieve.return_value = "https://some-image-uri" - mock_can_fit_on_single_gpu.return_value = False - - model_builder = ModelBuilder(model="stable-diffusion") - model_builder.build(sagemaker_session=mock_session) - - mock_build_for_djl.assert_called_once() - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1220,8 +1347,8 @@ def test_build_for_transformers_happy_case( @patch("sagemaker.serve.builder.model_builder.ModelBuilder._try_fetch_gpu_info") @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1259,12 +1386,12 @@ def test_build_for_transformers_happy_case_with_values( mock_build_for_transformers.assert_called_once() - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl", Mock()) + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) @patch("sagemaker.serve.builder.model_builder._get_gpu_info") @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1308,8 +1435,8 @@ def test_build_for_transformers_happy_case_with_valid_gpu_info( @patch("sagemaker.serve.builder.model_builder._get_gpu_info_fallback") @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1355,16 +1482,16 @@ def test_build_for_transformers_happy_case_with_valid_gpu_fallback( ) self.assertEqual(model_builder._can_fit_on_single_gpu(), True) - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - 
@patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @patch("sagemaker.serve.builder.model_builder._ServeSettings") - def test_build_is_fast_transformers_model( + def test_build_fallback_to_transformers( self, mock_serveSettings, mock_model_uris_retrieve, @@ -1374,7 +1501,7 @@ def test_build_is_fast_transformers_model( mock_model_urllib, mock_image_uris_retrieve, mock_can_fit_on_single_gpu, - mock_build_for_djl, + mock_build_for_transformers, ): mock_setting_object = mock_serveSettings.return_value mock_setting_object.role_arn = mock_role_arn @@ -1386,25 +1513,25 @@ def test_build_is_fast_transformers_model( mock_model_json.load.return_value = {"some": "config"} mock_model_urllib.request.Request.side_effect = Mock() + mock_build_for_transformers.side_effect = Mock() mock_image_uris_retrieve.return_value = "https://some-image-uri" mock_can_fit_on_single_gpu.return_value = False - model_builder = ModelBuilder(model="gpt_neo") + model_builder = ModelBuilder(model="gpt_llm_burt") model_builder.build(sagemaker_session=mock_session) - mock_build_for_djl.assert_called_once() + mock_build_for_transformers.assert_called_once() - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @patch("sagemaker.serve.builder.model_builder._ServeSettings") - def test_build_fallback_to_transformers( + def test_text_generation( self, mock_serveSettings, mock_model_uris_retrieve, @@ -1413,38 +1540,36 @@ def test_build_fallback_to_transformers( mock_model_json, mock_model_urllib, mock_image_uris_retrieve, - mock_can_fit_on_single_gpu, - mock_build_for_transformers, + mock_build_for_tgi, ): mock_setting_object = mock_serveSettings.return_value mock_setting_object.role_arn = mock_role_arn mock_setting_object.s3_model_data_url = mock_s3_model_data_url mock_model_uris_retrieve.side_effect = KeyError - mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"} + mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-generation"} mock_llm_utils_urllib.request.Request.side_effect = Mock() mock_model_json.load.return_value = {"some": "config"} mock_model_urllib.request.Request.side_effect = Mock() - mock_build_for_transformers.side_effect = Mock() + mock_build_for_tgi.side_effect = Mock() mock_image_uris_retrieve.return_value = "https://some-image-uri" - mock_can_fit_on_single_gpu.return_value = False - model_builder = ModelBuilder(model="gpt_llm_burt") + model_builder = ModelBuilder(model="bloom-560m") model_builder.build(sagemaker_session=mock_session) - mock_build_for_transformers.assert_called_once() + mock_build_for_tgi.assert_called_once() - @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei") 
@patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @patch("sagemaker.serve.builder.model_builder._ServeSettings") - def test_text_generation( + def test_sentence_similarity( self, mock_serveSettings, mock_model_uris_retrieve, @@ -1453,32 +1578,32 @@ def test_text_generation( mock_model_json, mock_model_urllib, mock_image_uris_retrieve, - mock_build_for_tgi, + mock_build_for_tei, ): mock_setting_object = mock_serveSettings.return_value mock_setting_object.role_arn = mock_role_arn mock_setting_object.s3_model_data_url = mock_s3_model_data_url mock_model_uris_retrieve.side_effect = KeyError - mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-generation"} + mock_llm_utils_json.load.return_value = {"pipeline_tag": "sentence-similarity"} mock_llm_utils_urllib.request.Request.side_effect = Mock() mock_model_json.load.return_value = {"some": "config"} mock_model_urllib.request.Request.side_effect = Mock() - mock_build_for_tgi.side_effect = Mock() + mock_build_for_tei.side_effect = Mock() mock_image_uris_retrieve.return_value = "https://some-image-uri" - model_builder = ModelBuilder(model="bloom-560m") + model_builder = ModelBuilder(model="bloom-560m", schema_builder=schema_builder) model_builder.build(sagemaker_session=mock_session) - mock_build_for_tgi.assert_called_once() + mock_build_for_tei.assert_called_once() @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) @patch("sagemaker.serve.builder.model_builder.ModelBuilder._try_fetch_gpu_info") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1516,8 +1641,8 @@ def test_try_fetch_gpu_info_throws( @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock()) @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1554,8 +1679,8 @@ def test_total_inference_model_size_mib_throws( @patch("sagemaker.serve.builder.tgi_builder.HuggingFaceModel") @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - @patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1601,8 +1726,8 @@ def test_build_happy_path_override_with_task_provided( self.assertEqual(sample_outputs, model_builder.schema_builder.sample_output) @patch("sagemaker.image_uris.retrieve") - @patch("sagemaker.djl_inference.model.urllib") - 
@patch("sagemaker.djl_inference.model.json") + @patch("sagemaker.serve.utils.hf_utils.urllib") + @patch("sagemaker.serve.utils.hf_utils.json") @patch("sagemaker.huggingface.llm_utils.urllib") @patch("sagemaker.huggingface.llm_utils.json") @patch("sagemaker.model_uris.retrieve") @@ -1677,6 +1802,7 @@ def test_build_task_override_with_invalid_model_provided( model_builder.build(sagemaker_session=mock_session) @patch("os.makedirs", Mock()) + @patch("sagemaker.serve.builder.model_builder._maintain_lineage_tracking_for_mlflow_model") @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") @patch("sagemaker.serve.builder.model_builder.save_pkl") @@ -1705,6 +1831,7 @@ def test_build_mlflow_model_local_input_happy( mock_save_pkl, mock_prepare_for_torchserve, mock_detect_fw_version, + mock_lineage_tracking, ): # setup mocks @@ -1750,6 +1877,85 @@ def test_build_mlflow_model_local_input_happy( self.assertEqual(build_result.serve_settings, mock_setting_object) self.assertEqual(builder.env_vars["MLFLOW_MODEL_FLAVOR"], "sklearn") + build_result.deploy( + initial_instance_count=1, instance_type=mock_instance_type, mode=Mode.SAGEMAKER_ENDPOINT + ) + mock_lineage_tracking.assert_called_once() + + @patch("os.makedirs", Mock()) + @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") + @patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve") + @patch("sagemaker.serve.builder.model_builder.save_pkl") + @patch("sagemaker.serve.builder.model_builder._copy_directory_contents") + @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path") + @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata") + @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode") + @patch("sagemaker.serve.builder.model_builder.Model") + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("os.path.isfile", return_value=True) + @patch("os.path.exists") + def test_build_mlflow_model_local_input_happy_flavor_server_mismatch( + self, + mock_path_exists, + mock_is_file, + mock_open, + mock_sdk_model, + mock_sageMakerEndpointMode, + mock_serveSettings, + mock_detect_container, + mock_get_all_flavor_metadata, + mock_generate_mlflow_artifact_path, + mock_copy_directory_contents, + mock_save_pkl, + mock_prepare_for_torchserve, + mock_detect_fw_version, + ): + # setup mocks + + mock_detect_container.return_value = mock_image_uri + mock_get_all_flavor_metadata.return_value = {"sklearn": "some_data"} + mock_generate_mlflow_artifact_path.return_value = "some_path" + + mock_prepare_for_torchserve.return_value = mock_secret_key + + # Mock _ServeSettings + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_path_exists.side_effect = lambda path: True if path == "test_path" else False + + mock_mode = Mock() + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode if inference_spec is None and model_server == ModelServer.TORCHSERVE else None + ) + mock_mode.prepare.return_value = ( + model_data, + ENV_VAR_PAIR, + ) + + updated_env_var = deepcopy(ENV_VARS) + updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "sklearn"}) + mock_model_obj = Mock() + 
mock_sdk_model.return_value = mock_model_obj + + mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent" + + # run + builder = ModelBuilder( + schema_builder=schema_builder, + model_metadata={"MLFLOW_MODEL_PATH": MODEL_PATH}, + model_server=ModelServer.TENSORFLOW_SERVING, + ) + with self.assertRaises(ValueError): + builder.build( + Mode.SAGEMAKER_ENDPOINT, + mock_role_arn, + mock_session, + ) + @patch("os.makedirs", Mock()) @patch("sagemaker.serve.builder.model_builder.S3Downloader.list") @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") @@ -1899,3 +2105,2175 @@ def test_build_mlflow_model_s3_input_non_mlflow_case( mock_role_arn, mock_session, ) + + @patch("os.makedirs", Mock()) + @patch("sagemaker.serve.builder.model_builder._maintain_lineage_tracking_for_mlflow_model") + @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving") + @patch("sagemaker.serve.builder.model_builder.S3Downloader.list") + @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version") + @patch("sagemaker.serve.builder.tf_serving_builder.save_pkl") + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path") + @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata") + @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model") + @patch("sagemaker.serve.builder.model_builder._ServeSettings") + @patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode") + @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel") + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("os.path.exists") + def test_build_mlflow_model_s3_input_tensorflow_serving_happy( + self, + mock_path_exists, + mock_open, + mock_sdk_model, + mock_sageMakerEndpointMode, + mock_serveSettings, + mock_detect_container, + mock_get_all_flavor_metadata, + mock_generate_mlflow_artifact_path, + mock_download_s3_artifacts, + mock_save_pkl, + mock_detect_fw_version, + mock_s3_downloader, + mock_prepare_for_tf_serving, + mock_lineage_tracking, + ): + # setup mocks + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + mock_detect_container.return_value = mock_image_uri + mock_get_all_flavor_metadata.return_value = {"tensorflow": "some_data"} + mock_generate_mlflow_artifact_path.return_value = "some_path" + + mock_prepare_for_tf_serving.return_value = mock_secret_key + + # Mock _ServeSettings + mock_setting_object = mock_serveSettings.return_value + mock_setting_object.role_arn = mock_role_arn + mock_setting_object.s3_model_data_url = mock_s3_model_data_url + + mock_path_exists.side_effect = lambda path: True if path == "test_path" else False + + mock_mode = Mock() + mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: ( + mock_mode + if inference_spec is None and model_server == ModelServer.TENSORFLOW_SERVING + else None + ) + mock_mode.prepare.return_value = ( + model_data, + ENV_VAR_PAIR, + ) + + updated_env_var = deepcopy(ENV_VARS) + updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "tensorflow"}) + mock_model_obj = Mock() + mock_sdk_model.return_value = mock_model_obj + + mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent" + + # run + builder = ModelBuilder( + schema_builder=schema_builder, model_metadata={"MLFLOW_MODEL_PATH": "s3://test_path/"} + ) + build_result = builder.build(sagemaker_session=mock_session) + 
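
# Editor's sketch (not part of the diff): the mismatch test above pins down a fail-fast
# behaviour, namely that building a sklearn-flavored MLflow model against
# ModelServer.TENSORFLOW_SERVING raises ValueError before anything is deployed.
# A hypothetical guard of that shape, with illustrative names:
def _check_flavor_server_compatibility(flavor_metadata: dict, model_server) -> None:
    # flavor_metadata is the MLmodel flavors dict, e.g. {"sklearn": ...}
    if model_server.name == "TENSORFLOW_SERVING" and "tensorflow" not in flavor_metadata:
        raise ValueError(
            f"Model flavor(s) {list(flavor_metadata)} cannot be served by {model_server.name}"
        )
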
 
     @patch("os.makedirs", Mock())
     @patch("sagemaker.serve.builder.model_builder.S3Downloader.list")
     @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
@@ -1899,3 +2105,2175 @@ def test_build_mlflow_model_s3_input_non_mlflow_case(
             mock_role_arn,
             mock_session,
         )
+
+    @patch("os.makedirs", Mock())
+    @patch("sagemaker.serve.builder.model_builder._maintain_lineage_tracking_for_mlflow_model")
+    @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving")
+    @patch("sagemaker.serve.builder.model_builder.S3Downloader.list")
+    @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
+    @patch("sagemaker.serve.builder.tf_serving_builder.save_pkl")
+    @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts")
+    @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path")
+    @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata")
+    @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model")
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
+    @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel")
+    @patch("builtins.open", new_callable=mock_open, read_data="data")
+    @patch("os.path.exists")
+    def test_build_mlflow_model_s3_input_tensorflow_serving_happy(
+        self,
+        mock_path_exists,
+        mock_open,
+        mock_sdk_model,
+        mock_sageMakerEndpointMode,
+        mock_serveSettings,
+        mock_detect_container,
+        mock_get_all_flavor_metadata,
+        mock_generate_mlflow_artifact_path,
+        mock_download_s3_artifacts,
+        mock_save_pkl,
+        mock_detect_fw_version,
+        mock_s3_downloader,
+        mock_prepare_for_tf_serving,
+        mock_lineage_tracking,
+    ):
+        # setup mocks
+        mock_s3_downloader.return_value = ["s3://some_path/MLmodel"]
+        mock_detect_container.return_value = mock_image_uri
+        mock_get_all_flavor_metadata.return_value = {"tensorflow": "some_data"}
+        mock_generate_mlflow_artifact_path.return_value = "some_path"
+        mock_prepare_for_tf_serving.return_value = mock_secret_key
+
+        # Mock _ServeSettings
+        mock_setting_object = mock_serveSettings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        mock_path_exists.side_effect = lambda path: True if path == "test_path" else False
+
+        mock_mode = Mock()
+        mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: (
+            mock_mode
+            if inference_spec is None and model_server == ModelServer.TENSORFLOW_SERVING
+            else None
+        )
+        mock_mode.prepare.return_value = (
+            model_data,
+            ENV_VAR_PAIR,
+        )
+
+        updated_env_var = deepcopy(ENV_VARS)
+        updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "tensorflow"})
+        mock_model_obj = Mock()
+        mock_sdk_model.return_value = mock_model_obj
+
+        mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent"
+
+        # run
+        builder = ModelBuilder(
+            schema_builder=schema_builder, model_metadata={"MLFLOW_MODEL_PATH": "s3://test_path/"}
+        )
+        build_result = builder.build(sagemaker_session=mock_session)
+        self.assertEqual(mock_model_obj, build_result)
+        self.assertEqual(build_result.mode, Mode.SAGEMAKER_ENDPOINT)
+        self.assertEqual(build_result.modes, {str(Mode.SAGEMAKER_ENDPOINT): mock_mode})
+        self.assertEqual(build_result.serve_settings, mock_setting_object)
+        self.assertEqual(builder.env_vars["MLFLOW_MODEL_FLAVOR"], "tensorflow")
+
+        build_result.deploy(
+            initial_instance_count=1, instance_type=mock_instance_type, mode=Mode.SAGEMAKER_ENDPOINT
+        )
+        mock_lineage_tracking.assert_called_once()
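
# Editor's sketch (not part of the diff): both happy-path tests assert that
# _maintain_lineage_tracking_for_mlflow_model runs exactly once after an endpoint
# deploy of an MLflow-sourced model. A hypothetical post-deploy hook of that shape;
# the wrapper name and keyword arguments are illustrative, not the SDK's real signature.
def deploy_with_lineage(model, sagemaker_session, mlflow_model_path, **deploy_kwargs):
    predictor = model.deploy(**deploy_kwargs)
    # record MLflow-to-SageMaker lineage once per successful deploy
    _maintain_lineage_tracking_for_mlflow_model(
        mlflow_model_path=mlflow_model_path,
        s3_upload_path=model.model_data,
        sagemaker_session=sagemaker_session,
    )
    return predictor
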
+
+    @patch("os.makedirs", Mock())
+    @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving")
+    @patch("sagemaker.serve.builder.model_builder.S3Downloader.list")
+    @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
+    @patch("sagemaker.serve.builder.tf_serving_builder.save_pkl")
+    @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts")
+    @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path")
+    @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata")
+    @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model")
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
+    @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel")
+    @patch("builtins.open", new_callable=mock_open, read_data="data")
+    @patch("os.path.exists")
+    def test_build_mlflow_model_s3_input_tensorflow_serving_local_mode_happy(
+        self,
+        mock_path_exists,
+        mock_open,
+        mock_sdk_model,
+        mock_local_container_mode,
+        mock_serveSettings,
+        mock_detect_container,
+        mock_get_all_flavor_metadata,
+        mock_generate_mlflow_artifact_path,
+        mock_download_s3_artifacts,
+        mock_save_pkl,
+        mock_detect_fw_version,
+        mock_s3_downloader,
+        mock_prepare_for_tf_serving,
+    ):
+        # setup mocks
+        mock_s3_downloader.return_value = ["s3://some_path/MLmodel"]
+        mock_detect_container.return_value = mock_image_uri
+        mock_get_all_flavor_metadata.return_value = {"tensorflow": "some_data"}
+        mock_generate_mlflow_artifact_path.return_value = "some_path"
+        mock_prepare_for_tf_serving.return_value = mock_secret_key
+
+        # Mock _ServeSettings
+        mock_setting_object = mock_serveSettings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        mock_path_exists.side_effect = lambda path: True if path == "test_path" else False
+
+        mock_mode = Mock()
+        mock_mode.prepare.side_effect = lambda: None
+        mock_local_container_mode.return_value = mock_mode
+        mock_mode.prepare.return_value = (
+            model_data,
+            ENV_VAR_PAIR,
+        )
+
+        updated_env_var = deepcopy(ENV_VARS)
+        updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "tensorflow"})
+        mock_model_obj = Mock()
+        mock_sdk_model.return_value = mock_model_obj
+
+        mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent"
+
+        # run
+        builder = ModelBuilder(
+            mode=Mode.LOCAL_CONTAINER,
+            schema_builder=schema_builder,
+            model_metadata={"MLFLOW_MODEL_PATH": "s3://test_path/"},
+        )
+        build_result = builder.build(sagemaker_session=mock_session)
+        self.assertEqual(mock_model_obj, build_result)
+        self.assertEqual(build_result.mode, Mode.LOCAL_CONTAINER)
+        self.assertEqual(build_result.modes, {str(Mode.LOCAL_CONTAINER): mock_mode})
+        self.assertEqual(build_result.serve_settings, mock_setting_object)
+        self.assertEqual(builder.env_vars["MLFLOW_MODEL_FLAVOR"], "tensorflow")
+
+        predictor = build_result.deploy()
+        assert isinstance(predictor, TensorflowServingLocalPredictor)
+
+    @patch("os.makedirs", Mock())
+    @patch(
+        "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id",
+        return_value=False,
+    )
+    @patch("sagemaker.serve.builder.tf_serving_builder.prepare_for_tf_serving")
+    @patch("sagemaker.serve.builder.model_builder.S3Downloader.list")
+    @patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
+    @patch("sagemaker.serve.builder.model_builder.save_pkl")
+    @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts")
+    @patch("sagemaker.serve.builder.model_builder._generate_mlflow_artifact_path")
+    @patch("sagemaker.serve.builder.model_builder._get_all_flavor_metadata")
+    @patch("sagemaker.serve.builder.model_builder._select_container_for_mlflow_model")
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
+    @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel")
+    @patch("builtins.open", new_callable=mock_open, read_data="data")
+    @patch("os.path.exists")
+    def test_build_tensorflow_serving_non_mlflow_case(
+        self,
+        mock_path_exists,
+        mock_open,
+        mock_sdk_model,
+        mock_sageMakerEndpointMode,
+        mock_serveSettings,
+        mock_detect_container,
+        mock_get_all_flavor_metadata,
+        mock_generate_mlflow_artifact_path,
+        mock_download_s3_artifacts,
+        mock_save_pkl,
+        mock_detect_fw_version,
+        mock_s3_downloader,
+        mock_prepare_for_tf_serving,
+        mock_is_jumpstart_model_id,
+    ):
+        mock_s3_downloader.return_value = []
+        mock_detect_container.return_value = mock_image_uri
+        mock_get_all_flavor_metadata.return_value = {"tensorflow": "some_data"}
+        mock_generate_mlflow_artifact_path.return_value = "some_path"
+        mock_prepare_for_tf_serving.return_value = mock_secret_key
+
+        # Mock _ServeSettings
+        mock_setting_object = mock_serveSettings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        mock_path_exists.side_effect = lambda path: True if path == "test_path" else False
+
+        mock_mode = Mock()
+        mock_sageMakerEndpointMode.side_effect = lambda inference_spec, model_server: (
+            mock_mode
+            if inference_spec is None and model_server == ModelServer.TENSORFLOW_SERVING
+            else None
+        )
+        mock_mode.prepare.return_value = (
+            model_data,
+            ENV_VAR_PAIR,
+        )
+
+        updated_env_var = deepcopy(ENV_VARS)
+        updated_env_var.update({"MLFLOW_MODEL_FLAVOR": "tensorflow"})
+        mock_model_obj = Mock()
+        mock_sdk_model.return_value = mock_model_obj
+
+        mock_session.sagemaker_client._user_agent_creator.to_string = lambda: "sample agent"
+
+        # run
+        builder = ModelBuilder(
+            model=mock_fw_model,
+            schema_builder=schema_builder,
+            model_server=ModelServer.TENSORFLOW_SERVING,
+        )
+
+        self.assertRaisesRegex(
+            Exception,
+            "Tensorflow Serving is currently only supported for mlflow models.",
+            builder.build,
+            Mode.SAGEMAKER_ENDPOINT,
+            mock_role_arn,
+            mock_session,
+        )
+
+    @patch.object(ModelBuilder, "_prepare_for_mode")
+    @patch.object(ModelBuilder, "_build_for_djl")
+    @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False)
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry")
+    def test_optimize(
+        self,
+        mock_send_telemetry,
+        mock_get_serve_setting,
+        mock_is_jumpstart_model_id,
+        mock_build_for_djl,
+        mock_prepare_for_mode,
+    ):
+        mock_sagemaker_session = Mock()
+
+        mock_settings = Mock()
+        mock_settings.telemetry_opt_out = False
+        mock_get_serve_setting.return_value = mock_settings
+
+        pysdk_model = Mock()
+        pysdk_model.env = {"key": "val"}
+        pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None
+
+        mock_build_for_djl.side_effect = lambda **kwargs: pysdk_model
+        mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
+            {
+                "S3DataSource": {
+                    "S3Uri": "s3://uri",
+                    "S3DataType": "S3Prefix",
+                    "CompressionType": "None",
+                }
+            },
+            {"key": "val"},
+        )
+
+        builder = ModelBuilder(
+            schema_builder=SchemaBuilder(
+                sample_input={"inputs": "Hello", "parameters": {}},
+                sample_output=[{"generated_text": "Hello"}],
+            ),
+            model="meta-llama/Meta-Llama-3-8B",
+            sagemaker_session=mock_sagemaker_session,
+            env_vars={"HF_TOKEN": "token"},
+            model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"},
+        )
+        builder.pysdk_model = pysdk_model
+
+        job_name = "my-optimization-job"
+        instance_type = "ml.g5.24xlarge"
+        output_path = "s3://my-bucket/output"
+        quantization_config = {
+            "Image": "quantization-image-uri",
+            "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
+        }
+        env_vars = {"Var1": "value", "Var2": "value"}
+        kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id"
+        max_runtime_in_sec = 3600
+        tags = [
+            {"Key": "Project", "Value": "my-project"},
+            {"Key": "Environment", "Value": "production"},
+        ]
+        vpc_config = {
+            "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"],
+            "Subnets": ["subnet-01234567", "subnet-89abcdef"],
+        }
+
+        mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: {
+            "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job",
+            "OptimizationJobName": "my-optimization-job",
+        }
+
+        builder.optimize(
+            instance_type=instance_type,
+            output_path=output_path,
+            role_arn=mock_role_arn,
+            job_name=job_name,
+            quantization_config=quantization_config,
+            env_vars=env_vars,
+            kms_key=kms_key,
+            max_runtime_in_sec=max_runtime_in_sec,
+            tags=tags,
+            vpc_config=vpc_config,
+        )
+
+        self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token")
+        self.assertEqual(builder.model_server, ModelServer.DJL_SERVING)
+
+        assert mock_send_telemetry.call_count == 2
+        mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with(
+            OptimizationJobName="my-optimization-job",
+            DeploymentInstanceType="ml.g5.24xlarge",
+            RoleArn="arn:aws:iam::123456789012:role/SageMakerRole",
+            OptimizationEnvironment={"Var1": "value", "Var2": "value"},
+            ModelSource={"S3": {"S3Uri": "s3://uri"}},
+            OptimizationConfigs=[
+                {
+                    "ModelQuantizationConfig": {
+                        "Image": "quantization-image-uri",
+                        "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
+                    }
+                }
+            ],
+            OutputConfig={
+                "S3OutputLocation": "s3://my-bucket/output",
+                "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id",
+            },
+            StoppingCondition={"MaxRuntimeInSeconds": 3600},
+            Tags=[
+                {"Key": "Project", "Value": "my-project"},
+                {"Key": "Environment", "Value": "production"},
+            ],
+            VpcConfig={
+                "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"],
+                "Subnets": ["subnet-01234567", "subnet-89abcdef"],
+            },
+        )
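
# Editor's sketch (not part of the diff): the assertion above shows how optimize()
# arguments map one-to-one onto a CreateOptimizationJob request. A minimal,
# hypothetical assembler; the helper name and plumbing are illustrative, while the
# request keys themselves are exactly those asserted in the test.
def build_optimization_request(job_name, instance_type, role_arn, model_s3_uri,
                               quantization_config, env_vars, output_path, kms_key,
                               max_runtime_in_sec, tags, vpc_config):
    return {
        "OptimizationJobName": job_name,
        "DeploymentInstanceType": instance_type,
        "RoleArn": role_arn,
        "OptimizationEnvironment": env_vars,
        "ModelSource": {"S3": {"S3Uri": model_s3_uri}},
        "OptimizationConfigs": [{"ModelQuantizationConfig": quantization_config}],
        "OutputConfig": {"S3OutputLocation": output_path, "KmsKeyId": kms_key},
        "StoppingCondition": {"MaxRuntimeInSeconds": max_runtime_in_sec},
        "Tags": tags,
        "VpcConfig": vpc_config,
    }
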
@patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_run_id( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_run, + mock_set_tracking_uri, + mock_find_spec, + ): + mock_find_spec.return_value = True + mock_run_info = Mock() + mock_run_info.info.artifact_uri = "s3://bucket/path" + mock_get_run.return_value = mock_run_info + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "runs:/runid/mlflow-path", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + } + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path/mlflow-path") + + @patch("importlib.util.find_spec") + @patch("mlflow.set_tracking_uri") + @patch("mlflow.MlflowClient.get_model_version") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_registry_path_with_model_version( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_model_version, + mock_set_tracking_uri, + mock_find_spec, + ): + mock_find_spec.return_value = True + mock_registry_path = Mock() + mock_registry_path.source = "s3://bucket/path/" + mock_get_model_version.return_value = mock_registry_path + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "models:/model-name/1", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + } + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path/") + + @patch("importlib.util.find_spec") + @patch("mlflow.set_tracking_uri") + @patch("mlflow.MlflowClient.get_model_version_by_alias") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True) + @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True) + @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts") + @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow") + def test_handle_mlflow_input_registry_path_with_model_alias( + self, + mock_validate, + mock_s3_downloader, + mock_initialize, + mock_check_mlflow_model, + mock_get_model_version_by_alias, + mock_set_tracking_uri, + mock_find_spec, + ): + mock_find_spec.return_value = True + mock_registry_path = Mock() + mock_registry_path.source = "s3://bucket/path" + mock_get_model_version_by_alias.return_value = mock_registry_path + mock_check_mlflow_model.return_value = True + mock_s3_downloader.return_value = ["s3://some_path/MLmodel"] + + builder = ModelBuilder( + model_metadata={ + "MLFLOW_MODEL_PATH": "models:/model-name@production", + "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", + } + ) + builder._handle_mlflow_input() + mock_initialize.assert_called_once_with(builder, "s3://bucket/path/") + + @patch("mlflow.MlflowClient.get_model_version") + @patch.object(ModelBuilder, "_mlflow_metadata_exists", 
+
+    @patch("mlflow.MlflowClient.get_model_version")
+    @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True)
+    @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True)
+    @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts")
+    @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow")
+    def test_handle_mlflow_input_registry_path_missing_tracking_server_arn(
+        self,
+        mock_validate,
+        mock_s3_downloader,
+        mock_initialize,
+        mock_check_mlflow_model,
+        mock_get_model_version,
+    ):
+        mock_registry_path = Mock()
+        mock_registry_path.source = "s3://bucket/path"
+        mock_get_model_version.return_value = mock_registry_path
+        mock_check_mlflow_model.return_value = True
+        mock_s3_downloader.return_value = ["s3://some_path/MLmodel"]
+
+        builder = ModelBuilder(
+            model_metadata={
+                "MLFLOW_MODEL_PATH": "models:/model-name/1",
+            }
+        )
+        self.assertRaisesRegex(
+            Exception,
+            "%s is not provided in ModelMetadata or through set_tracking_arn "
+            "but MLflow model path was provided." % MLFLOW_TRACKING_ARN,
+            builder._handle_mlflow_input,
+        )
+
+    @patch.object(ModelBuilder, "_mlflow_metadata_exists", autospec=True)
+    @patch.object(ModelBuilder, "_initialize_for_mlflow", autospec=True)
+    @patch("sagemaker.serve.builder.model_builder._download_s3_artifacts")
+    @patch("sagemaker.serve.builder.model_builder._validate_input_for_mlflow")
+    def test_handle_mlflow_input_model_package_arn(
+        self, mock_validate, mock_s3_downloader, mock_initialize, mock_check_mlflow_model
+    ):
+        mock_check_mlflow_model.return_value = True
+        mock_s3_downloader.return_value = ["s3://some_path/MLmodel"]
+        mock_model_package = {"SourceUri": "s3://bucket/path"}
+        mock_session.sagemaker_client.describe_model_package.return_value = mock_model_package
+
+        builder = ModelBuilder(
+            model_metadata={
+                "MLFLOW_MODEL_PATH": "arn:aws:sagemaker:us-west-2:000000000000:model-package/test",
+                "MLFLOW_TRACKING_ARN": "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test",
+            },
+            sagemaker_session=mock_session,
+        )
+        builder._handle_mlflow_input()
+        mock_initialize.assert_called_once_with(builder, "s3://bucket/path")
+
+    @patch("importlib.util.find_spec", Mock(return_value=True))
+    @patch("mlflow.set_tracking_uri")
+    def test_set_tracking_arn_success(self, mock_set_tracking_uri):
+        builder = ModelBuilder(
+            model_metadata={
+                "MLFLOW_MODEL_PATH": "arn:aws:sagemaker:us-west-2:000000000000:model-package/test",
+            }
+        )
+        tracking_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test"
+        builder.set_tracking_arn(tracking_arn)
+        mock_set_tracking_uri.assert_called_once_with(tracking_arn)
+        assert builder.model_metadata[MLFLOW_TRACKING_ARN] == tracking_arn
+
+    @patch("importlib.util.find_spec", Mock(return_value=False))
+    def test_set_tracking_arn_mlflow_not_installed(self):
+        builder = ModelBuilder(
+            model_metadata={
+                "MLFLOW_MODEL_PATH": "arn:aws:sagemaker:us-west-2:000000000000:model-package/test",
+            }
+        )
+        tracking_arn = "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test"
+        self.assertRaisesRegex(
+            ImportError,
+            "Unable to import sagemaker_mlflow, check if sagemaker_mlflow is installed",
+            builder.set_tracking_arn,
+            tracking_arn,
+        )
+
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_local_mode(self, mock_get_serve_setting):
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-70b", mode=Mode.LOCAL_CONTAINER
+        )
+
+        self.assertRaisesRegex(
+            ValueError,
+            "Model optimization is only supported in Sagemaker Endpoint Mode.",
+            lambda: model_builder.optimize(
+                instance_type="ml.g5.24xlarge",
+                quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+            ),
+        )
+
+    @patch.object(ModelBuilder, "_prepare_for_mode")
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_for_hf_with_both_quantization_and_compilation(
+        self,
+        mock_get_serve_setting,
+        mock_prepare_for_mode,
+    ):
+        mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
+            {
+                "S3DataSource": {
+                    "CompressionType": "None",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://bucket/code/code/",
+                }
+            },
+            {"DTYPE": "bfloat16"},
+        )
+
+        mock_pysdk_model = Mock()
+        mock_pysdk_model.model_data = None
+        mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-8B-Instruct"}
+
+        model_builder = ModelBuilder(
+            model="meta-llama/Meta-Llama-3-8B-Instruct",
+            env_vars={"HF_TOKEN": "token"},
+            model_metadata={
+                "CUSTOM_MODEL_PATH": "s3://bucket/path/",
+            },
+            role_arn="role-arn",
+            instance_type="ml.g5.2xlarge",
+        )
+
+        model_builder.pysdk_model = mock_pysdk_model
+
+        output = model_builder._optimize_for_hf(
+            job_name="job_name-123",
+            quantization_config={
+                "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
+            },
+            compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
+            output_path="s3://bucket/code/",
+        )
+
+        self.assertEqual(model_builder.env_vars["HF_TOKEN"], "token")
+        self.assertEqual(model_builder.role_arn, "role-arn")
+        self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge")
+        self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq")
+        self.assertEqual(model_builder.pysdk_model.env["OPTION_TENSOR_PARALLEL_DEGREE"], "2")
+        self.assertEqual(
+            output,
+            {
+                "OptimizationJobName": "job_name-123",
+                "DeploymentInstanceType": "ml.g5.2xlarge",
+                "RoleArn": "role-arn",
+                "ModelSource": {"S3": {"S3Uri": "s3://bucket/code/code/"}},
+                "OptimizationConfigs": [
+                    {
+                        "ModelQuantizationConfig": {
+                            "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}
+                        }
+                    },
+                    {
+                        "ModelCompilationConfig": {
+                            "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}
+                        }
+                    },
+                ],
+                "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"},
+            },
+        )
+
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
+        mock_sagemaker_session = Mock()
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-70b",
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        self.assertRaisesRegex(
+            ValueError,
+            "Optimizations that use Compilation and Sharding are not supported for GPU instances.",
+            lambda: model_builder.optimize(
+                instance_type="ml.g5.24xlarge",
+                quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+                compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+                sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+            ),
+        )
+
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
+        mock_sagemaker_session = Mock()
+        model_builder = ModelBuilder(
+            model="meta-textgeneration-llama-3-70b",
+            sagemaker_session=mock_sagemaker_session,
+        )
+
+        self.assertRaisesRegex(
+            ValueError,
+            "OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with sharding config.",
+            lambda: model_builder.optimize(
+                instance_type="ml.g5.24xlarge",
+                sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
+            ),
+        )
+
+    @patch.object(ModelBuilder, "_prepare_for_mode")
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_for_hf_with_custom_s3_path(
+        self,
+        mock_get_serve_setting,
+        mock_prepare_for_mode,
+    ):
+        mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
+            {
+                "S3DataSource": {
+                    "CompressionType": "None",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://bucket/code/code/",
+                }
+            },
+            {"DTYPE": "bfloat16"},
+        )
+
+        mock_pysdk_model = Mock()
+        mock_pysdk_model.model_data = None
+        mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-8B-Instruct"}
+
+        model_builder = ModelBuilder(
+            model="meta-llama/Meta-Llama-3-8B-Instruct",
+            env_vars={"HF_TOKEN": "token"},
+            model_metadata={
+                "CUSTOM_MODEL_PATH": "s3://bucket/path/",
+            },
+            role_arn="role-arn",
+            instance_type="ml.g5.2xlarge",
+        )
+
+        model_builder.pysdk_model = mock_pysdk_model
+
+        output = model_builder._optimize_for_hf(
+            job_name="job_name-123",
+            quantization_config={
+                "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
+            },
+            output_path="s3://bucket/code/",
+        )
+
+        self.assertEqual(model_builder.env_vars["HF_TOKEN"], "token")
+        self.assertEqual(model_builder.role_arn, "role-arn")
+        self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge")
+        self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq")
+        self.assertEqual(
+            output,
+            {
+                "OptimizationJobName": "job_name-123",
+                "DeploymentInstanceType": "ml.g5.2xlarge",
+                "RoleArn": "role-arn",
+                "ModelSource": {"S3": {"S3Uri": "s3://bucket/code/code/"}},
+                "OptimizationConfigs": [
+                    {"ModelQuantizationConfig": {"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}}
+                ],
+                "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"},
+            },
+        )
+
+    @patch(
+        "sagemaker.serve.builder.model_builder.download_huggingface_model_metadata", autospec=True
+    )
+    @patch.object(ModelBuilder, "_prepare_for_mode")
+    @patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
+    def test_optimize_for_hf_without_custom_s3_path(
+        self,
+        mock_get_serve_setting,
+        mock_prepare_for_mode,
+        mock_download_huggingface_model_metadata,
+    ):
+        mock_prepare_for_mode.side_effect = lambda *args, **kwargs: (
+            {
+                "S3DataSource": {
+                    "CompressionType": "None",
+                    "S3DataType": "S3Prefix",
+                    "S3Uri": "s3://bucket/code/code/",
+                }
+            },
+            {"DTYPE": "bfloat16"},
+        )
+
+        mock_pysdk_model = Mock()
+        mock_pysdk_model.model_data = None
+        mock_pysdk_model.env = {"HF_MODEL_ID": "meta-llama/Meta-Llama-3-8B-Instruct"}
+
+        model_builder = ModelBuilder(
+            model="meta-llama/Meta-Llama-3-8B-Instruct",
+            env_vars={"HUGGING_FACE_HUB_TOKEN": "token"},
+            role_arn="role-arn",
+            instance_type="ml.g5.2xlarge",
+        )
+
+        model_builder.pysdk_model = mock_pysdk_model
+
+        output = model_builder._optimize_for_hf(
+            job_name="job_name-123",
+            quantization_config={
+                "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
+            },
+            output_path="s3://bucket/code/",
+        )
+
+        self.assertEqual(model_builder.role_arn, "role-arn")
+        self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge")
+        self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq")
+        self.assertEqual(
+            output,
+            {
+                "OptimizationJobName": "job_name-123",
+                "DeploymentInstanceType": "ml.g5.2xlarge",
+                "RoleArn": "role-arn",
+                "ModelSource": {"S3": {"S3Uri": "s3://bucket/code/code/"}},
+                "OptimizationConfigs": [
+                    {"ModelQuantizationConfig": {"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}}
+                ],
+                "OutputConfig": {"S3OutputLocation": "s3://bucket/code/"},
+            },
+        )
+
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_jumpstart")
+    @patch(
+        "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id",
+        return_value=True,
+    )
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+        return_value=MagicMock(),
+    )
+    def test_build_multiple_inference_component_modelbuilders(
+        self,
+        mock_pre_trained_model,
+        mock_is_jumpstart_model_id,
+        mock_build_for_js,
+        mock_serve_settings,
+    ):
+        mock_setting_object = mock_serve_settings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        builder1 = ModelBuilder(
+            model="gpt_llm_burt", inference_component_name="ic1", resource_requirements=Mock()
+        )
+        builder2 = ModelBuilder(
+            model="gpt_llm_burt", inference_component_name="ic2", resource_requirements=Mock()
+        )
+        builder3 = ModelBuilder(
+            model="gpt_llm_burt", inference_component_name="ic3", resource_requirements=Mock()
+        )
+
+        chain_builder = ModelBuilder(
+            modelbuilder_list=[builder1, builder2, builder3],
+        )
+        chain_builder.build(sagemaker_session=mock_session)
+        assert mock_build_for_js.call_count == 3
+
+    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
+    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_jumpstart")
+    @patch(
+        "sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id",
+        return_value=True,
+    )
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+        return_value=MagicMock(),
+    )
+    @patch(
+        "sagemaker.serve.builder.model_builder.ModelBuilder._does_ic_exist",
+        return_value=True,
+    )
+    @patch(
+        "sagemaker.session.Session.update_inference_component",
+        return_value=MagicMock(),
+    )
+    def test_deploy_existing_inference_component_calls_update_inference_component(
+        self,
+        mock_update_inference_component,
+        mock_ic_exists,
+        mock_pre_trained_model,
+        mock_is_jumpstart_model_id,
+        mock_build_for_js,
+        mock_serve_settings,
+    ):
+        mock_setting_object = mock_serve_settings.return_value
+        mock_setting_object.role_arn = mock_role_arn
+        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
+
+        builder1 = ModelBuilder(
+            model="gpt_llm_burt", inference_component_name="ic1", resource_requirements=Mock()
+        )
+
+        chain_builder = ModelBuilder(
+            modelbuilder_list=[builder1],
+        ).build()
+        inputs = {"endpoint_name": "endpoint-001"}
+        chain_builder.deploy(**inputs)
+        assert mock_update_inference_component.call_count == 1
+
+    def test_deploy_invalid_inputs(self):
+        model_builder = ModelBuilder(
+            model="meta-llama/Meta-Llama-3-8B-Instruct",
+            env_vars={"HUGGING_FACE_HUB_TOKEN": "token"},
+            role_arn="role-arn",
+            instance_type="ml.g5.2xlarge",
+        )
+        inputs = {"endpoint_name": "endpoint-001"}
+
+        with self.assertRaisesRegex(ValueError, "Model needs to be built before deploying"):
+            model_builder.deploy(**inputs)
+
+    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id")
+    def test_display_benchmark_metrics_non_string_model(self, mock_is_jumpstart):
+        """Test that ValueError is raised when model is not a string"""
+        builder = ModelBuilder(model=Mock())  # Non-string model
+
+        self.assertRaisesRegex(
+            ValueError,
+            "Benchmarking is only supported for JumpStart or HuggingFace models",
+            builder.display_benchmark_metrics,
+        )
+        mock_is_jumpstart.assert_not_called()
@patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.display_benchmark_metrics") + def test_display_benchmark_metrics_jumpstart_model( + self, mock_display_benchmark_metrics, mock_is_jumpstart + ): + """Test successful execution for jumpstart model""" + mock_is_jumpstart.return_value = True + + builder = ModelBuilder(model="jumpstart-model-id") + builder.display_benchmark_metrics() + + mock_is_jumpstart.assert_called_once() + mock_display_benchmark_metrics.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.display_benchmark_metrics") + def test_display_benchmark_metrics_with_jumpstart_equivalent( + self, mock_display_benchmark_metrics, mock_has_equivalent, mock_is_jumpstart + ): + """Test successful execution for model with jumpstart equivalent""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = True + + builder = ModelBuilder(model="hf-model-id") + builder.display_benchmark_metrics() + + mock_is_jumpstart.assert_called_once() + mock_has_equivalent.assert_called_once() + mock_display_benchmark_metrics.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + def test_display_benchmark_metrics_unsupported_model( + self, mock_has_equivalent, mock_is_jumpstart + ): + """Test that ValueError is raised for unsupported models""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = False + + builder = ModelBuilder(model="huggingface-model-id") + + self.assertRaisesRegex( + ValueError, + "This model does not have benchmark metrics yet", + builder.display_benchmark_metrics, + ) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_get_deployment_config_non_string_model(self, mock_is_jumpstart): + """Test that ValueError is raised when model is not a string""" + builder = ModelBuilder(model=Mock()) # Non-string model + + self.assertRaisesRegex( + ValueError, + "Deployment config is only supported for JumpStart or HuggingFace models", + builder.get_deployment_config, + ) + mock_is_jumpstart.assert_not_called() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.get_deployment_config") + def test_get_deployment_config_jumpstart_model( + self, mock_get_deployment_config, mock_is_jumpstart + ): + """Test successful execution for jumpstart model""" + mock_is_jumpstart.return_value = True + + builder = ModelBuilder(model="jumpstart-model-id") + builder.get_deployment_config() + + mock_is_jumpstart.assert_called_once() + mock_get_deployment_config.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.get_deployment_config") + def test_get_deployment_config_with_jumpstart_equivalent( + self, mock_get_deployment_config, mock_has_equivalent, mock_is_jumpstart + ): + """Test successful execution for model with jumpstart equivalent""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = True + + builder = ModelBuilder(model="hf-model-id") + 
builder.get_deployment_config() + + mock_is_jumpstart.assert_called_once() + mock_has_equivalent.assert_called_once() + mock_get_deployment_config.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + def test_get_deployment_config_unsupported_model(self, mock_has_equivalent, mock_is_jumpstart): + """Test that ValueError is raised for unsupported models""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = False + + builder = ModelBuilder(model="huggingface-model-id") + + self.assertRaisesRegex( + ValueError, + "This model does not have any deployment config yet", + builder.get_deployment_config, + ) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_list_deployment_configs_non_string_model(self, mock_is_jumpstart): + """Test that ValueError is raised when model is not a string""" + builder = ModelBuilder(model=Mock()) # Non-string model + + self.assertRaisesRegex( + ValueError, + "Deployment config is only supported for JumpStart or HuggingFace models", + builder.list_deployment_configs, + ) + mock_is_jumpstart.assert_not_called() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.list_deployment_configs") + def test_list_deployment_configs_jumpstart_model( + self, mock_list_deployment_configs, mock_is_jumpstart + ): + """Test successful execution for jumpstart model""" + mock_is_jumpstart.return_value = True + + builder = ModelBuilder(model="jumpstart-model-id") + builder.list_deployment_configs() + + mock_is_jumpstart.assert_called_once() + mock_list_deployment_configs.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.list_deployment_configs") + def test_list_deployment_configs_with_jumpstart_equivalent( + self, mock_list_deployment_configs, mock_has_equivalent, mock_is_jumpstart + ): + """Test successful execution for model with jumpstart equivalent""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = True + + builder = ModelBuilder(model="hf-model-id") + builder.list_deployment_configs() + + mock_is_jumpstart.assert_called_once() + mock_has_equivalent.assert_called_once() + mock_list_deployment_configs.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + def test_list_deployment_configs_unsupported_model( + self, mock_has_equivalent, mock_is_jumpstart + ): + """Test that ValueError is raised for unsupported models""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = False + + builder = ModelBuilder(model="huggingface-model-id") + + self.assertRaisesRegex( + ValueError, + "This model does not have any deployment config yet", + builder.list_deployment_configs, + ) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_set_deployment_config_non_string_model(self, mock_is_jumpstart): + """Test that ValueError is raised when model is not a string""" + builder = ModelBuilder(model=Mock()) # Non-string model + instance_type = 
"ml.g5.xlarge" + config_name = "config-name" + self.assertRaisesRegex( + ValueError, + "Deployment config is only supported for JumpStart or HuggingFace models", + builder.set_deployment_config, + config_name, + instance_type, + ) + mock_is_jumpstart.assert_not_called() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.set_deployment_config") + def test_set_deployment_config_jumpstart_model( + self, mock_set_deployment_config, mock_is_jumpstart + ): + """Test successful execution for jumpstart model""" + mock_is_jumpstart.return_value = True + instance_type = "ml.g5.xlarge" + config_name = "config-name" + + builder = ModelBuilder(model="jumpstart-model-id") + builder.set_deployment_config(config_name, instance_type) + + mock_is_jumpstart.assert_called_once() + mock_set_deployment_config.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + @patch("sagemaker.serve.builder.jumpstart_builder.JumpStart.set_deployment_config") + def test_set_deployment_config_with_jumpstart_equivalent( + self, mock_set_deployment_config, mock_has_equivalent, mock_is_jumpstart + ): + """Test successful execution for model with jumpstart equivalent""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = True + instance_type = "ml.g5.xlarge" + config_name = "config-name" + + builder = ModelBuilder(model="hf-model-id") + builder.set_deployment_config(config_name, instance_type) + + mock_is_jumpstart.assert_called_once() + mock_has_equivalent.assert_called_once() + mock_set_deployment_config.assert_called_once() + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_jumpstart_model_id") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._use_jumpstart_equivalent") + def test_set_deployment_config_unsupported_model(self, mock_has_equivalent, mock_is_jumpstart): + """Test that ValueError is raised for unsupported models""" + mock_is_jumpstart.return_value = False + mock_has_equivalent.return_value = False + instance_type = "ml.g5.xlarge" + config_name = "config-name" + + builder = ModelBuilder(model="huggingface-model-id") + + self.assertRaisesRegex( + ValueError, + f"The deployment config {config_name} cannot be set on this model", + builder.set_deployment_config, + config_name, + instance_type, + ) + + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._retrieve_hugging_face_model_mapping" + ) + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_gated_model") + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_jumpstart") + def test_use_jumpstart_equivalent_return_true( + self, mock_build_for_jumpstart, mock_is_gated_model, mock_retrieve_mapping + ): + """Test that _use_jumpstart_equivalent returns True when equivalent exists""" + mock_retrieve_mapping.return_value = { + "HuggingFaceH4/zephyr-7b-beta": { + "jumpstart-model-id": "js-model", + "jumpstart-model-version": "1.0.0", + "hf-model-repo-sha": None, + } + } + mock_is_gated_model.return_value = False + + builder = ModelBuilder(model="HuggingFaceH4/zephyr-7b-beta") + + self.assertTrue(builder._use_jumpstart_equivalent()) + + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._retrieve_hugging_face_model_mapping" + ) + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._is_gated_model") + 
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_jumpstart") + def test_use_jumpstart_equivalent_return_true_with_schema_builder( + self, mock_build_for_jumpstart, mock_is_gated_model, mock_retrieve_mapping + ): + """Test that _use_jumpstart_equivalent returns True when equivalent exists""" + mock_retrieve_mapping.return_value = { + "HuggingFaceH4/zephyr-7b-beta": { + "jumpstart-model-id": "js-model", + "jumpstart-model-version": "1.0.0", + "hf-model-repo-sha": None, + } + } + mock_is_gated_model.return_value = False + + builder = ModelBuilder(model="HuggingFaceH4/zephyr-7b-beta", sagemaker_session=mock_session) + + self.assertTrue(builder._use_jumpstart_equivalent()) + self.assertIsNotNone(builder.schema_builder) + inputs, outputs = task.retrieve_local_schemas("text-generation") + self.assertEqual(builder.schema_builder.sample_input["inputs"], inputs["inputs"]) + self.assertEqual(builder.schema_builder.sample_output, outputs) + + @patch( + "sagemaker.serve.builder.model_builder.ModelBuilder._retrieve_hugging_face_model_mapping" + ) + def test_use_jumpstart_equivalent_return_false(self, mock_retrieve_mapping): + """Test that _use_jumpstart_equivalent returns false when equivalent doesn't exist""" + mock_retrieve_mapping.return_value = { + "hf-model-id": { + "jumpstart-model-id": "js-model", + "jumpstart-model-version": "1.0.0", + "hf-model-repo-sha": None, + } + } + + builder = ModelBuilder(model="model-id") + + self.assertFalse(builder._use_jumpstart_equivalent()) + + def test_use_jumpstart_equivalent_return_false_with_env_vars(self): + """Test that _use_jumpstart_equivalent returns false when env_vars is provided""" + builder = ModelBuilder(model="model-id", env_vars={"mock-key": "mock-value"}) + + self.assertFalse(builder._use_jumpstart_equivalent()) + + def test_use_jumpstart_equivalent_return_false_with_image_uri(self): + """Test that _use_jumpstart_equivalent returns false when image_uri is provided""" + builder = ModelBuilder(model="model-id", image_uri="mock-uri") + + self.assertFalse(builder._use_jumpstart_equivalent()) + + @patch("sagemaker.serve.builder.model_builder.JumpStartS3PayloadAccessor") + @patch("sagemaker.serve.builder.model_builder.get_jumpstart_content_bucket") + def test_retrieve_hugging_face_model_mapping(self, mock_content_bucket, mock_payload_accessor): + """Test that _retrieve_hugging_face_model_mapping returns the correct mapping""" + mock_get_object = Mock() + mock_get_object.return_value = ( + '{"js-model-id": {"hf-model-id": "hf-model", "jumpstart-model-version": "1.0.0"}}' + ) + mock_payload_accessor.get_object_cached = mock_get_object + expected_mapping = { + "hf-model": { + "jumpstart-model-id": "js-model-id", + "jumpstart-model-version": "1.0.0", + "hf-model-repo-sha": None, + "merged-at": None, + } + } + + builder = ModelBuilder(model="hf-model", sagemaker_session=mock_session) + + self.assertEqual(builder._retrieve_hugging_face_model_mapping(), expected_mapping) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": "bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": 
"meta-llama/Meta-Llama-3-2-8B-Instruct"} + + sample_input = {"inputs": "dummy prompt", "parameters": {}} + + sample_output = [{"generated_text": "dummy response"}] + + dummy_schema_builder = SchemaBuilder(sample_input, sample_output) + + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-2-8B-Instruct", + schema_builder=dummy_schema_builder, + env_vars={"HF_TOKEN": "token"}, + model_metadata={ + "CUSTOM_MODEL_PATH": "s3://bucket/path/", + }, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + + model_builder.pysdk_model = mock_pysdk_model + + self.assertRaisesRegex( + ValueError, + "Compilation is not supported for models greater than Llama-3.0 with a GPU instance.", + lambda: model_builder.optimize( + job_name="job_name-123", + instance_type="ml.g5.24xlarge", + compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + output_path="s3://bucket/code/", + ), + ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_with_gpu_instance_and_compilation_with_speculative_decoding( + self, + mock_get_serve_setting, + mock_prepare_for_mode, + ): + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://bucket/code/code/", + } + }, + {"DTYPE": "bfloat16"}, + ) + + mock_pysdk_model = Mock() + mock_pysdk_model.model_data = None + mock_pysdk_model.env = {"HF_MODEL_ID": "modelid"} + + sample_input = {"inputs": "dummy prompt", "parameters": {}} + + sample_output = [{"generated_text": "dummy response"}] + + dummy_schema_builder = SchemaBuilder(sample_input, sample_output) + + model_builder = ModelBuilder( + model="modelid", + schema_builder=dummy_schema_builder, + env_vars={"HF_TOKEN": "token"}, + model_metadata={ + "CUSTOM_MODEL_PATH": "s3://bucket/path/", + }, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + + model_builder.pysdk_model = mock_pysdk_model + + self.assertRaisesRegex( + ValueError, + "Optimizations that use Compilation and Speculative Decoding are not supported for GPU instances.", + lambda: model_builder.optimize( + job_name="job_name-123", + instance_type="ml.g5.24xlarge", + speculative_decoding_config={ + "ModelProvider": "custom", + "ModelSource": "s3://data-source", + }, + compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, + output_path="s3://bucket/code/", + ), + ) + + +class TestModelBuilderOptimizationSharding(unittest.TestCase): + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_djl") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize_sharding_with_env_vars( + self, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_djl, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_djl.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + 
"CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"key": "value"} + env_vars = {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) + + assert mock_send_telemetry.call_count == 2 + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + DeploymentInstanceType="ml.g5.24xlarge", + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + ModelSource={"S3": {"S3Uri": "s3://uri"}}, + OptimizationConfigs=[{"ModelShardingConfig": {"key": "value"}}], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_djl") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize_sharding_with_override_and_env_var( + self, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_djl, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_djl.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", 
+ "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}} + env_vars = {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) + + assert mock_send_telemetry.call_count == 2 + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + DeploymentInstanceType="ml.g5.24xlarge", + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + ModelSource={"S3": {"S3Uri": "s3://uri"}}, + OptimizationConfigs=[ + { + "ModelShardingConfig": { + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + } + } + ], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_djl") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize_sharding_with_override( + self, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_djl, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_djl.side_effect = lambda 
**kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}} + env_vars = {"Var1": "value", "Var2": "value"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) + + assert mock_send_telemetry.call_count == 2 + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + DeploymentInstanceType="ml.g5.24xlarge", + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={"Var1": "value", "Var2": "value"}, + ModelSource={"S3": {"S3Uri": "s3://uri"}}, + OptimizationConfigs=[ + { + "ModelShardingConfig": { + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + } + } + ], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + # squeeze in some validations + with self.assertRaises(ValueError): + builder.enable_network_isolation = True + builder.optimize(sharding_config={}) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_jumpstart") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=True) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", return_value=False + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._find_compatible_deployment_config", + 
return_value=Mock(), + ) + def test_optimize_sharding_with_override_for_js( + self, + mock_find_compatible_deployment_config, + mock_is_gated_model, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_jumpstart, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model._enable_network_isolation = True + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + pysdk_model.init_kwargs = { + "image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124" + } + + mock_build_for_jumpstart.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}} + env_vars = {"Var1": "value", "Var2": "value"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + model = builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + + assert mock_send_telemetry.call_count == 2 + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + ModelSource={"S3": {"S3Uri": ANY}}, + DeploymentInstanceType="ml.g5.24xlarge", + OptimizationConfigs=[ + { + "ModelShardingConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124", + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + }, + } + ], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={ + "key": "val", + "Var1": "value", + "Var2": "value", + "OPTION_TENSOR_PARALLEL_DEGREE": "1", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + 
{"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + assert not model._enable_network_isolation + + def test_model_sharding_with_eni_fails(self): + test_model = Model(role="mock role") + test_model._is_sharded_model = True + test_model._enable_network_isolation = True + self.assertRaisesRegex( + ValueError, + ( + "EnableNetworkIsolation cannot be set to True since " + "SageMaker Fast Model Loading of model requires network access." + ), + lambda: test_model.deploy(initial_instance_count=1, instance_type="ml.g5.24xlarge"), + ) + + +class TestModelBuilderOptimizeValidations(unittest.TestCase): + + def test_corner_cases_throw_errors(self): + self.assertRaisesRegex( + ValueError, + "Optimizations that uses None instance type are not currently supported", + lambda: _validate_optimization_configuration( + is_jumpstart=False, + sharding_config={"key": "value"}, + instance_type=None, + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + ), + ) + + self.assertRaisesRegex( + ValueError, + ( + "Optimizations that provide no optimization configs " + "are currently not support on both GPU and Neuron instances." + ), + lambda: _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + sharding_config=None, + ), + ) + + _validate_optimization_configuration( + is_jumpstart=True, + instance_type="ml.inf2.xlarge", + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + sharding_config=None, + ) + + def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self): + # Quantization:smoothquant without compilation + self.assertRaisesRegex( + ValueError, + "Optimizations that use Quantization:smoothquant must be provided with Compilation for GPU instances.", + lambda: _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"}, + }, + sharding_config=None, + speculative_decoding_config=None, + compilation_config=None, + ), + ) + + # Invalid quantization technique + self.assertRaisesRegex( + ValueError, + "Optimizations that use Quantization:test are not supported for GPU instances.", + lambda: _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "test"}, + }, + sharding_config=None, + speculative_decoding_config=None, + compilation_config=None, + ), + ) + + def test_neuron_configurations_throw_errors_for_rule_set(self): + self.assertRaisesRegex( + ValueError, + "Optimizations that use Speculative Decoding are not supported on Neuron instances.", + lambda: _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.inf2.xlarge", + quantization_config=None, + speculative_decoding_config={"key": "value"}, + compilation_config=None, + sharding_config=None, + ), + ) + + self.assertRaisesRegex( + ValueError, + "Optimizations that use Sharding are not supported on Neuron instances.", + lambda: _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.inf2.xlarge", + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + sharding_config={"key": "value"}, + ), + ) + + 
def test_trt_configurations_rule_set(self):
+        # Can be compiled with quantization
+        _validate_optimization_configuration(
+            is_jumpstart=False,
+            instance_type="ml.g5.24xlarge",
+            quantization_config={
+                "OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"},
+            },
+            sharding_config=None,
+            speculative_decoding_config=None,
+            compilation_config={"key": "value"},
+        )
+
+        # Can be just compiled
+        _validate_optimization_configuration(
+            is_jumpstart=False,
+            instance_type="ml.g5.24xlarge",
+            quantization_config=None,
+            sharding_config=None,
+            speculative_decoding_config=None,
+            compilation_config={"key": "value"},
+        )
+
+        # Can be just compiled with empty dict
+        _validate_optimization_configuration(
+            is_jumpstart=False,
+            instance_type="ml.g5.24xlarge",
+            quantization_config=None,
+            sharding_config=None,
+            speculative_decoding_config=None,
+            compilation_config={},
+        )
+
+    def test_vllm_configurations_rule_set(self):
+        # Can use speculative decoding
+        _validate_optimization_configuration(
+            is_jumpstart=False,
+            instance_type="ml.g5.24xlarge",
+            quantization_config=None,
+            sharding_config=None,
+            speculative_decoding_config={"key": "value"},
+            compilation_config=None,
+        )
+
+        # Can be quantized
+        _validate_optimization_configuration(
+            is_jumpstart=False,
+            instance_type="ml.g5.24xlarge",
+            quantization_config={
+                "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
+            },
+            sharding_config=None,
+            speculative_decoding_config=None,
+            compilation_config=None,
+        )
+
+        # Can be sharded
+        _validate_optimization_configuration(
+            is_jumpstart=False,
+            instance_type="ml.g5.24xlarge",
+            quantization_config=None,
+            sharding_config={"key": "value"},
+            speculative_decoding_config=None,
+            compilation_config=None,
+        )
+
+    def test_neuron_configurations_rule_set(self):
+        # Can be compiled
+        _validate_optimization_configuration(
+            is_jumpstart=False,
+            instance_type="ml.inf2.xlarge",
+            quantization_config=None,
+            sharding_config=None,
+            speculative_decoding_config=None,
+            compilation_config={"key": "value"},
+        )
+
+        # Can be compiled with empty dict
+        _validate_optimization_configuration(
+            is_jumpstart=False,
+            instance_type="ml.inf2.xlarge",
+            quantization_config=None,
+            sharding_config=None,
+            speculative_decoding_config=None,
+            compilation_config={},
+        )
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        # Real-time deployment without update
+        {
+            "input_args": {"endpoint_name": "test"},
+            "call_params": {
+                "instance_type": "ml.g5.2xlarge",
+                "initial_instance_count": 1,
+                "endpoint_name": "test",
+                "update_endpoint": False,
+            },
+        },
+        # Real-time deployment with update
+        {
+            "input_args": {
+                "endpoint_name": "existing-endpoint",
+                "update_endpoint": True,
+            },
+            "call_params": {
+                "instance_type": "ml.g5.2xlarge",
+                "initial_instance_count": 1,
+                "endpoint_name": "existing-endpoint",
+                "update_endpoint": True,
+            },
+        },
+        # Serverless deployment without update
+        {
+            "input_args": {
+                "endpoint_name": "test",
+                "inference_config": ServerlessInferenceConfig(),
+            },
+            "call_params": {
+                "serverless_inference_config": ServerlessInferenceConfig(),
+                "endpoint_name": "test",
+                "update_endpoint": False,
+            },
+        },
+        # Serverless deployment with update
+        {
+            "input_args": {
+                "endpoint_name": "existing-endpoint",
+                "inference_config": ServerlessInferenceConfig(),
+                "update_endpoint": True,
+            },
+            "call_params": {
+                "serverless_inference_config": ServerlessInferenceConfig(),
+                "endpoint_name": "existing-endpoint",
+                "update_endpoint": True,
+            },
+        },
+        # Async deployment without update
+        {
+            "input_args": {
+
"endpoint_name": "test", + "inference_config": AsyncInferenceConfig(output_path="op-path"), + }, + "call_params": { + "async_inference_config": AsyncInferenceConfig(output_path="op-path"), + "instance_type": "ml.g5.2xlarge", + "initial_instance_count": 1, + "endpoint_name": "test", + "update_endpoint": False, + }, + }, + # Async deployment with update + { + "input_args": { + "endpoint_name": "existing-endpoint", + "inference_config": AsyncInferenceConfig(output_path="op-path"), + "update_endpoint": True, + }, + "call_params": { + "async_inference_config": AsyncInferenceConfig(output_path="op-path"), + "instance_type": "ml.g5.2xlarge", + "initial_instance_count": 1, + "endpoint_name": "existing-endpoint", + "update_endpoint": True, + }, + }, + # Multi-Model deployment (update_endpoint not supported) + { + "input_args": { + "endpoint_name": "test", + "inference_config": RESOURCE_REQUIREMENTS, + }, + "call_params": { + "resources": RESOURCE_REQUIREMENTS, + "role": "role-arn", + "initial_instance_count": 1, + "instance_type": "ml.g5.2xlarge", + "mode": Mode.SAGEMAKER_ENDPOINT, + "endpoint_type": EndpointType.INFERENCE_COMPONENT_BASED, + "update_endpoint": False, + }, + }, + # Batch transform + { + "input_args": { + "inference_config": BatchTransformInferenceConfig( + instance_count=1, instance_type="ml.m5.large", output_path="op-path" + ) + }, + "call_params": { + "instance_count": 1, + "instance_type": "ml.m5.large", + "output_path": "op-path", + }, + "id": "Batch", + }, + ], + ids=[ + "Real Time", + "Real Time Update", + "Serverless", + "Serverless Update", + "Async", + "Async Update", + "Multi-Model", + "Batch", + ], +) +def test_deploy(test_case): + model: Model = MagicMock() + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HUGGING_FACE_HUB_TOKEN": "token"}, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + setattr(model_builder, "built_model", model) + + model_builder.deploy(**test_case["input_args"]) + + if "id" in test_case and test_case["id"] == "Batch": + args, kwargs = model.transformer.call_args_list[0] + else: + args, kwargs = model.deploy.call_args_list[0] + + diff = deepdiff.DeepDiff(kwargs, test_case["call_params"]) + assert diff == {} + + +def test_deploy_multi_model_update_error(): + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HUGGING_FACE_HUB_TOKEN": "token"}, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + setattr(model_builder, "built_model", MagicMock()) + + with pytest.raises( + ValueError, match="Currently update_endpoint is supported for single model endpoints" + ): + model_builder.deploy( + endpoint_name="test", inference_config=RESOURCE_REQUIREMENTS, update_endpoint=True + ) diff --git a/tests/unit/sagemaker/serve/builder/test_requirements_manager.py b/tests/unit/sagemaker/serve/builder/test_requirements_manager.py new file mode 100644 index 0000000000..b6886ab0a6 --- /dev/null +++ b/tests/unit/sagemaker/serve/builder/test_requirements_manager.py @@ -0,0 +1,81 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import unittest
+from unittest.mock import patch, call
+
+from sagemaker.serve.builder.requirements_manager import RequirementsManager
+
+
+class TestRequirementsManager(unittest.TestCase):
+
+    @patch(
+        "sagemaker.serve.builder.requirements_manager.RequirementsManager._update_conda_env_in_path"
+    )
+    @patch(
+        "sagemaker.serve.builder.requirements_manager.RequirementsManager._install_requirements_txt"
+    )
+    @patch(
+        "sagemaker.serve.builder.requirements_manager.RequirementsManager._detect_conda_env_and_local_dependencies"
+    )
+    def test_capture_and_install_dependencies_txt(
+        self,
+        mock_detect_conda_env_and_local_dependencies,
+        mock_install_requirements_txt,
+        mock_update_conda_env_in_path,
+    ):
+
+        mock_detect_conda_env_and_local_dependencies.side_effect = lambda: ".txt"
+        RequirementsManager().capture_and_install_dependencies()
+        mock_install_requirements_txt.assert_called_once()
+
+        RequirementsManager().capture_and_install_dependencies("conda.yml")
+        mock_update_conda_env_in_path.assert_called_once()
+
+    @patch(
+        "sagemaker.serve.builder.requirements_manager.RequirementsManager._detect_conda_env_and_local_dependencies"
+    )
+    def test_capture_and_install_dependencies_fail(
+        self, mock_detect_conda_env_and_local_dependencies
+    ):
+        mock_dependencies = "mock.ini"
+        mock_detect_conda_env_and_local_dependencies.side_effect = lambda: "invalid requirement"
+        self.assertRaises(
+            ValueError,
+            lambda: RequirementsManager().capture_and_install_dependencies(mock_dependencies),
+        )
+
+    @patch("sagemaker.serve.builder.requirements_manager.logger")
+    @patch("sagemaker.serve.builder.requirements_manager.subprocess")
+    def test_install_requirements_txt(self, mock_subprocess, mock_logger):
+
+        RequirementsManager()._install_requirements_txt()
+
+        calls = [call("Running command to pip install"), call("Command ran successfully")]
+        mock_logger.info.assert_has_calls(calls)
+        mock_subprocess.run.assert_called_once_with(
+            "pip install -r in_process_requirements.txt", shell=True, check=True
+        )
+
+    @patch("sagemaker.serve.builder.requirements_manager.logger")
+    @patch("sagemaker.serve.builder.requirements_manager.subprocess")
+    def test_update_conda_env_in_path(self, mock_subprocess, mock_logger):
+
+        RequirementsManager()._update_conda_env_in_path()
+
+        calls = [call("Updating conda env"), call("Conda env updated successfully")]
+        mock_logger.info.assert_has_calls(calls)
+        mock_subprocess.run.assert_called_once_with(
+            "conda env update -f conda_in_process.yml", shell=True, check=True
+        )
diff --git a/tests/unit/sagemaker/serve/builder/test_tei_builder.py b/tests/unit/sagemaker/serve/builder/test_tei_builder.py
new file mode 100644
index 0000000000..74e49e345f
--- /dev/null
+++ b/tests/unit/sagemaker/serve/builder/test_tei_builder.py
@@ -0,0 +1,236 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
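+"""Unit tests for the TEI path of ModelBuilder: covers SAGEMAKER_ENDPOINT,
+LOCAL_CONTAINER, optimized, and image-URI-override deployment flows."""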
+from __future__ import absolute_import +from unittest.mock import MagicMock, patch + +import unittest +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serve.mode.function_pointers import Mode +from tests.unit.sagemaker.serve.constants import MOCK_VPC_CONFIG + +from sagemaker.serve.utils.predictors import TeiLocalModePredictor + +MOCK_MODEL_ID = "bert-base-uncased" +MOCK_PROMPT = "The man worked as a [MASK]." +MOCK_SAMPLE_INPUT = {"inputs": MOCK_PROMPT} +MOCK_SAMPLE_OUTPUT = [ + { + "score": 0.0974755585193634, + "token": 10533, + "token_str": "carpenter", + "sequence": "the man worked as a carpenter.", + }, + { + "score": 0.052383411675691605, + "token": 15610, + "token_str": "waiter", + "sequence": "the man worked as a waiter.", + }, + { + "score": 0.04962712526321411, + "token": 13362, + "token_str": "barber", + "sequence": "the man worked as a barber.", + }, + { + "score": 0.0378861166536808, + "token": 15893, + "token_str": "mechanic", + "sequence": "the man worked as a mechanic.", + }, + { + "score": 0.037680838257074356, + "token": 18968, + "token_str": "salesman", + "sequence": "the man worked as a salesman.", + }, +] +MOCK_SCHEMA_BUILDER = MagicMock() +MOCK_SCHEMA_BUILDER.sample_input = MOCK_SAMPLE_INPUT +MOCK_SCHEMA_BUILDER.sample_output = MOCK_SAMPLE_OUTPUT +MOCK_IMAGE_CONFIG = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0" +) +MOCK_MODEL_PATH = "mock model path" + + +class TestTEIBuilder(unittest.TestCase): + @patch( + "sagemaker.serve.builder.tei_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None) + def test_tei_builder_sagemaker_endpoint_mode_no_s3_upload_success( + self, + mock_get_nb_instance, + mock_telemetry, + ): + # verify SAGEMAKER_ENDPOINT deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + name="mock_model_name", + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.SAGEMAKER_ENDPOINT, + model_metadata={ + "HF_TASK": "sentence-similarity", + }, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model = builder.build() + assert model.name == "mock_model_name" + + builder.serve_settings.telemetry_opt_out = True + builder._original_deploy = MagicMock() + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.assert_called_with() + + @patch( + "sagemaker.serve.builder.tei_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None) + def test_tei_builder_overwritten_deploy_from_local_container_to_sagemaker_endpoint_success( + self, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + vpc_config=MOCK_VPC_CONFIG, + model_metadata={ + "HF_TASK": "sentence-similarity", + }, + model_path=MOCK_MODEL_PATH, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + + predictor = model.deploy(model_data_download_timeout=1800) + + assert 
model.vpc_config == MOCK_VPC_CONFIG + assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" + assert isinstance(predictor, TeiLocalModePredictor) + assert builder.nb_instance_type == "ml.g5.24xlarge" + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.call_args_list[1].assert_called_once_with( + model_path=MOCK_MODEL_PATH, should_upload_artifacts=True + ) + + @patch( + "sagemaker.serve.builder.tei_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None) + @patch("sagemaker.serve.builder.tei_builder._is_optimized", return_value=True) + def test_tei_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success( + self, + mock_is_optimized, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + vpc_config=MOCK_VPC_CONFIG, + model_metadata={ + "HF_TASK": "sentence-similarity", + }, + model_path=MOCK_MODEL_PATH, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + + model.deploy(model_data_download_timeout=1800) + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + # verify that if optimized, no s3 upload occurs + builder._prepare_for_mode.assert_called_with() + + @patch( + "sagemaker.serve.builder.tei_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tei_builder._capture_telemetry", side_effect=None) + def test_tei_builder_image_uri_override_success( + self, + mock_get_nb_instance, + mock_telemetry, + ): + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + image_uri=MOCK_IMAGE_CONFIG, + model_metadata={ + "HF_TASK": "sentence-similarity", + }, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + predictor = model.deploy(model_data_download_timeout=1800) + + assert builder.image_uri == MOCK_IMAGE_CONFIG + assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" + assert isinstance(predictor, TeiLocalModePredictor) + + assert builder.nb_instance_type == "ml.g5.24xlarge" + + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + predictor = model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + assert "HF_MODEL_ID" in model.env + + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) diff --git a/tests/unit/sagemaker/serve/builder/test_tensorflow_serving_builder.py b/tests/unit/sagemaker/serve/builder/test_tensorflow_serving_builder.py new file mode 100644 index 0000000000..e8ae892b45 --- /dev/null +++ 
b/tests/unit/sagemaker/serve/builder/test_tensorflow_serving_builder.py
@@ -0,0 +1,77 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+from unittest.mock import MagicMock, patch
+
+import unittest
+from pathlib import Path
+
+from sagemaker.serve import ModelBuilder, ModelServer
+
+
+class TestTensorflowServingBuilder(unittest.TestCase):
+    def setUp(self):
+        self.instance = ModelBuilder()
+        self.instance.model_server = ModelServer.TENSORFLOW_SERVING
+        self.instance.model_path = "/fake/model/path"
+        self.instance.image_uri = "fake_image_uri"
+        self.instance.s3_upload_path = "s3://bucket/path"
+        self.instance.serve_settings = MagicMock(role_arn="fake_role_arn")
+        self.instance.schema_builder = MagicMock()
+        self.instance.env_vars = {}
+        self.instance.sagemaker_session = MagicMock()
+        self.instance.image_config = {}
+        self.instance.vpc_config = {}
+        self.instance.modes = {}
+        self.instance.name = "model-name-mock-uuid-hex"
+
+    @patch("os.makedirs")
+    @patch("os.path.exists")
+    @patch("sagemaker.serve.builder.tf_serving_builder.save_pkl")
+    def test_save_schema_builder(self, mock_save_pkl, mock_exists, mock_makedirs):
+        mock_exists.return_value = False
+        self.instance._save_schema_builder()
+        mock_makedirs.assert_called_once_with(self.instance.model_path)
+        code_path = Path(self.instance.model_path).joinpath("code")
+        mock_save_pkl.assert_called_once_with(code_path, self.instance.schema_builder)
+
+    @patch("sagemaker.serve.builder.tf_serving_builder.TensorflowServing._get_client_translators")
+    @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowPredictor")
+    def test_get_tensorflow_predictor(self, mock_predictor, mock_get_marshaller):
+        endpoint_name = "test_endpoint"
+        predictor = self.instance._get_tensorflow_predictor(
+            endpoint_name, self.instance.sagemaker_session
+        )
+        mock_predictor.assert_called_once_with(
+            endpoint_name=endpoint_name,
+            sagemaker_session=self.instance.sagemaker_session,
+            serializer=self.instance.schema_builder.custom_input_translator,
+            deserializer=self.instance.schema_builder.custom_output_translator,
+        )
+        self.assertEqual(predictor, mock_predictor.return_value)
+
+    @patch("sagemaker.serve.builder.tf_serving_builder.TensorFlowModel")
+    def test_create_tensorflow_model(self, mock_model):
+        model = self.instance._create_tensorflow_model()
+        mock_model.assert_called_once_with(
+            image_uri=self.instance.image_uri,
+            image_config=self.instance.image_config,
+            vpc_config=self.instance.vpc_config,
+            model_data=self.instance.s3_upload_path,
+            role=self.instance.serve_settings.role_arn,
+            env=self.instance.env_vars,
+            sagemaker_session=self.instance.sagemaker_session,
+            predictor_cls=self.instance._get_tensorflow_predictor,
+            name="model-name-mock-uuid-hex",
+        )
+        self.assertEqual(model, mock_model.return_value)
diff --git a/tests/unit/sagemaker/serve/builder/test_tgi_builder.py b/tests/unit/sagemaker/serve/builder/test_tgi_builder.py
new file mode 100644
index 0000000000..22109c93e2
--- /dev/null
+++
b/tests/unit/sagemaker/serve/builder/test_tgi_builder.py @@ -0,0 +1,310 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest import TestCase +from unittest.mock import MagicMock, patch +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.serve.utils.predictors import TgiLocalModePredictor + +MOCK_MODEL_ID_GATED = "meta-llama/Meta-Llama-3-8B" +MOCK_MODEL_ID_NON_GATED = "openai-community/gpt2.0" +MOCK_PROMPT = "The man worked as a [MASK]." +MOCK_SAMPLE_INPUT = {"inputs": "Hello, I'm a language model", "parameters": {"max_new_tokens": 128}} +MOCK_SAMPLE_OUTPUT = [{"generated_text": "Hello, I'm a language modeler."}] +MOCK_SCHEMA_BUILDER = MagicMock() +MOCK_SCHEMA_BUILDER.sample_input = MOCK_SAMPLE_INPUT +MOCK_SCHEMA_BUILDER.sample_output = MOCK_SAMPLE_OUTPUT +MOCK_IMAGE_CONFIG = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0" +) +MOCK_MODEL_PATH = "mock model path" + + +class TestTGIBuilder(TestCase): + @patch( + "sagemaker.serve.builder.tgi_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tgi_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={"pipeline_tag": "text-generation"}, + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_model_config_properties_from_hf", + return_value=({}, None), + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_default_tgi_configurations", + return_value=({}, None), + ) + def test_tgi_builder_sagemaker_endpoint_mode_no_s3_upload_success( + self, + mock_default_tgi_configurations, + mock_hf_model_config, + mock_hf_model_md, + mock_get_nb_instance, + mock_telemetry, + ): + # verify SAGEMAKER_ENDPOINT deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID_NON_GATED, + name="mock_model_name", + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.SAGEMAKER_ENDPOINT, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model = builder.build() + assert model.name == "mock_model_name" + + builder.serve_settings.telemetry_opt_out = True + builder._original_deploy = MagicMock() + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.assert_called_with() + + @patch( + "sagemaker.serve.builder.tgi_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tgi_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={"pipeline_tag": "text-generation"}, + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_model_config_properties_from_hf", + 
return_value=({}, None), + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_default_tgi_configurations", + return_value=({}, None), + ) + def test_tgi_builder_overwritten_deploy_from_local_container_to_sagemaker_endpoint_success( + self, + mock_default_tgi_configurations, + mock_hf_model_config, + mock_hf_model_md, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID_NON_GATED, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + model_path=MOCK_MODEL_PATH, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + + predictor = model.deploy(model_data_download_timeout=1800) + + assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" + assert isinstance(predictor, TgiLocalModePredictor) + assert builder.nb_instance_type == "ml.g5.24xlarge" + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.call_args_list[1].assert_called_once_with( + model_path=MOCK_MODEL_PATH, should_upload_artifacts=True + ) + + @patch( + "sagemaker.serve.builder.tgi_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tgi_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={"pipeline_tag": "text-generation"}, + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_model_config_properties_from_hf", + return_value=({}, None), + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_default_tgi_configurations", + return_value=({}, None), + ) + @patch("sagemaker.serve.builder.tgi_builder._is_optimized", return_value=True) + def test_tgi_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success( + self, + mock_is_optimized, + mock_default_tgi_configurations, + mock_hf_model_config, + mock_hf_model_md, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID_NON_GATED, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + model_path=MOCK_MODEL_PATH, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + + model.deploy(model_data_download_timeout=1800) + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + # verify that if optimized, no s3 upload occurs + builder._prepare_for_mode.assert_called_with() + + @patch( + "sagemaker.serve.builder.tgi_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tgi_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={"pipeline_tag": "text-generation"}, + ) + @patch( + 
"sagemaker.serve.builder.tgi_builder._get_model_config_properties_from_hf", + return_value=({}, None), + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_default_tgi_configurations", + return_value=({}, None), + ) + def test_tgi_builder_in_process_mode( + self, + mock_default_tgi_configurations, + mock_hf_model_config, + mock_hf_model_md, + mock_get_nb_instance, + mock_telemetry, + ): + # verify IN_PROCESS deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID_GATED, schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.IN_PROCESS + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.IN_PROCESS)] = MagicMock() + + model.deploy() + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + # verify that if optimized, no s3 upload occurs + builder._prepare_for_mode.assert_called_with() + + @patch( + "sagemaker.serve.builder.tgi_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.tgi_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={"pipeline_tag": "text-generation"}, + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_model_config_properties_from_hf", + return_value=({}, None), + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_default_tgi_configurations", + return_value=({}, None), + ) + @patch( + "sagemaker.serve.builder.tgi_builder._get_admissible_tensor_parallel_degrees", + return_value=[4, 8], + ) + @patch("sagemaker.serve.builder.tgi_builder._get_admissible_dtypes", return_value=["fp16"]) + @patch("sagemaker.serve.builder.tgi_builder.datetime") + @patch("sagemaker.serve.builder.tgi_builder.timedelta", return_value=1800) + @patch("sagemaker.serve.builder.tgi_builder._serial_benchmark") + @patch("sagemaker.serve.builder.tgi_builder._concurrent_benchmark") + def test_tgi_builder_tune_success( + self, + mock_concurrent_benchmark, + mock_serial_benchmark, + mock_timedelta, + mock_datetime, + mock_get_admissible_dtypes, + mock_get_admissible_tensor_parallel_degrees, + mock_default_tgi_configurations, + mock_hf_model_config, + mock_hf_model_md, + mock_get_nb_instance, + mock_telemetry, + ): + # WHERE + mock_datetime.now.side_effect = [0, 100, 200] + mock_serial_benchmark.side_effect = [(1000, 10000, 10), (500, 5000, 50)] + mock_concurrent_benchmark.side_effect = [(10, 10), (50, 5)] + + builder = ModelBuilder( + model=MOCK_MODEL_ID_NON_GATED, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + model_path=MOCK_MODEL_PATH, + ) + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + + model = builder.build() + + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + builder.pysdk_model = MagicMock() + + # WHEN + ret_new_model = model.tune(max_tuning_duration=1800) + + # THEN + assert ret_new_model != model + assert len(mock_datetime.now.call_args_list) == 3 + assert len(mock_serial_benchmark.call_args_list) == 2 + assert len(mock_concurrent_benchmark.call_args_list) == 2 + assert ret_new_model.env["NUM_SHARD"] == "8" + assert ret_new_model.env["DTYPE"] == "fp16" + assert ret_new_model.env["SHARDED"] == "true" diff --git a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py 
b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py index e17364f22d..a5e269ea51 100644 --- a/tests/unit/sagemaker/serve/builder/test_transformers_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_transformers_builder.py @@ -20,10 +20,10 @@ from sagemaker.serve.utils.predictors import TransformersLocalModePredictor -mock_model_id = "bert-base-uncased" -mock_prompt = "The man worked as a [MASK]." -mock_sample_input = {"inputs": mock_prompt} -mock_sample_output = [ +MOCK_MODEL_ID = "bert-base-uncased" +MOCK_PROMPT = "The man worked as a [MASK]." +MOCK_SAMPLE_INPUT = {"inputs": MOCK_PROMPT} +MOCK_SAMPLE_OUTPUT = [ { "score": 0.0974755585193634, "token": 10533, @@ -55,9 +55,14 @@ "sequence": "the man worked as a salesman.", }, ] -mock_schema_builder = MagicMock() -mock_schema_builder.sample_input = mock_sample_input -mock_schema_builder.sample_output = mock_sample_output +MOCK_SCHEMA_BUILDER = MagicMock() +MOCK_SCHEMA_BUILDER.sample_input = MOCK_SAMPLE_INPUT +MOCK_SCHEMA_BUILDER.sample_output = MOCK_SAMPLE_OUTPUT +MOCK_IMAGE_CONFIG = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04-v1.0" +) +MOCK_MODEL_PATH = "mock model path" class TestTransformersBuilder(unittest.TestCase): @@ -66,16 +71,126 @@ class TestTransformersBuilder(unittest.TestCase): return_value="ml.g5.24xlarge", ) @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) - def test_build_deploy_for_transformers_local_container_and_remote_container( + def test_transformers_builder_sagemaker_endpoint_mode_no_s3_upload_success( self, mock_get_nb_instance, mock_telemetry, ): + # verify SAGEMAKER_ENDPOINT deploy builder = ModelBuilder( - model=mock_model_id, - schema_builder=mock_schema_builder, + model=MOCK_MODEL_ID, schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.SAGEMAKER_ENDPOINT + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder._original_deploy = MagicMock() + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.assert_called_with() + + @patch( + "sagemaker.serve.builder.transformers_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) + def test_transformers_builder_overwritten_deploy_from_local_container_to_sagemaker_endpoint_success( + self, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + vpc_config=MOCK_VPC_CONFIG, + model_path=MOCK_MODEL_PATH, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + + predictor = model.deploy(model_data_download_timeout=1800) + + assert model.vpc_config == MOCK_VPC_CONFIG + assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" + assert isinstance(predictor, TransformersLocalModePredictor) + assert builder.nb_instance_type == "ml.g5.24xlarge" + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + 
builder._prepare_for_mode.return_value = (None, {}) + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + assert "HF_MODEL_ID" in model.env + with self.assertRaises(ValueError) as _: + model.deploy(mode=Mode.IN_PROCESS) + builder._prepare_for_mode.call_args_list[1].assert_called_once_with( + model_path=MOCK_MODEL_PATH, should_upload_artifacts=True + ) + + @patch( + "sagemaker.serve.builder.transformers_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) + @patch("sagemaker.serve.builder.transformers_builder._is_optimized", return_value=True) + def test_transformers_builder_optimized_sagemaker_endpoint_mode_no_s3_upload_success( + self, + mock_is_optimized, + mock_get_nb_instance, + mock_telemetry, + ): + # verify LOCAL_CONTAINER deploy + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, mode=Mode.LOCAL_CONTAINER, vpc_config=MOCK_VPC_CONFIG, + model_path=MOCK_MODEL_PATH, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() + + model.deploy(model_data_download_timeout=1800) + + # verify SAGEMAKER_ENDPOINT overwritten deploy + builder._original_deploy = MagicMock() + builder._prepare_for_mode.return_value = (None, {}) + + model.deploy(mode=Mode.SAGEMAKER_ENDPOINT, role="mock_role_arn") + + builder._prepare_for_mode.assert_called_once_with() + + @patch( + "sagemaker.serve.builder.transformers_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) + def test_transformers_builder_image_uri_override_success( + self, + mock_get_nb_instance, + mock_telemetry, + ): + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + image_uri=MOCK_IMAGE_CONFIG, ) builder._prepare_for_mode = MagicMock() @@ -87,7 +202,7 @@ def test_build_deploy_for_transformers_local_container_and_remote_container( builder.modes[str(Mode.LOCAL_CONTAINER)] = MagicMock() predictor = model.deploy(model_data_download_timeout=1800) - assert model.vpc_config == MOCK_VPC_CONFIG + assert builder.image_uri == MOCK_IMAGE_CONFIG assert builder.env_vars["MODEL_LOADING_TIMEOUT"] == "1800" assert isinstance(predictor, TransformersLocalModePredictor) @@ -100,3 +215,29 @@ def test_build_deploy_for_transformers_local_container_and_remote_container( with self.assertRaises(ValueError) as _: model.deploy(mode=Mode.IN_PROCESS) + + @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers") + @patch( + "sagemaker.serve.builder.transformers_builder._get_nb_instance", + return_value="ml.g5.24xlarge", + ) + @patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.model_builder.get_huggingface_model_metadata", + return_value={}, + ) + def test_transformers_builder_empty_hf_md_defaults_to_transformers_success( + self, mock_model_md, mock_get_nb_instance, mock_telemetry, mock_build_for_transformers + ): + builder = ModelBuilder( + model=MOCK_MODEL_ID, + schema_builder=MOCK_SCHEMA_BUILDER, + mode=Mode.LOCAL_CONTAINER, + ) + + builder._prepare_for_mode = MagicMock() + builder._prepare_for_mode.side_effect = None + + builder.build() + + 
mock_build_for_transformers.assert_called_once() diff --git a/tests/unit/sagemaker/serve/constants.py b/tests/unit/sagemaker/serve/constants.py index db9dd623d8..3e776eaa46 100644 --- a/tests/unit/sagemaker/serve/constants.py +++ b/tests/unit/sagemaker/serve/constants.py @@ -15,3 +15,303 @@ MOCK_IMAGE_CONFIG = {"RepositoryAccessMode": "Vpc"} MOCK_VPC_CONFIG = {"Subnets": ["subnet-1234"], "SecurityGroupIds": ["sg123"]} +DEPLOYMENT_CONFIGS = [ + { + "ConfigName": "neuron-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "neuron-inference-budget", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "gpu-inference-budget", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + 
"SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, + { + "ConfigName": "gpu-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, + }, +] +NON_OPTIMIZED_DEPLOYMENT_CONFIG = { + "ConfigName": "neuron-inference", + "BenchmarkMetrics": [ + {"name": "Latency", "value": "100", "unit": "Tokens/S"}, + {"name": "Throughput", "value": "1867", "unit": "Tokens/S"}, + ], + "DeploymentArgs": { + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4" + ".0-gpu-py310-cu121-ubuntu20.04", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration" + "-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "InstanceType": "ml.p2.xlarge", + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 16384, + "NumberOfAcceleratorDevicesRequired": 1, + }, + }, +} +OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL = { + "DeploymentConfigName": "lmi-optimized", + "DeploymentArgs": { + "ImageUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "djl-inference:0.29.0-lmi11.0.0-cu124", + "ModelData": { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-1-70b/artifacts/inference-prepack/v2.0.0/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "ModelPackageArn": None, + "Environment": { + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "OPTION_SPECULATIVE_DRAFT_MODEL": 
"/opt/ml/additional-model-data-sources/draft_model", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + "InstanceType": "ml.g6.2xlarge", + "ComputeResourceRequirements": { + "MinMemoryRequiredInMb": 131072, + "NumberOfAcceleratorDevicesRequired": 1, + }, + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "AdditionalDataSources": { + "speculative_decoding": [ + { + "channel_name": "draft_model", + "provider": {"name": "JumpStart", "classification": "gated"}, + "artifact_version": "v1", + "hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt", + "s3_data_source": { + "s3_uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/" + "inference-prepack/v1.0.0/", + "compression_type": "None", + "s3_data_type": "S3Prefix", + }, + } + ] + }, + }, + "AccelerationConfigs": [ + { + "type": "Compilation", + "enabled": False, + "diy_workflow_overrides": { + "gpu-lmi-trt": { + "enabled": False, + "reason": "TRT-LLM 0.11.0 in LMI v11 does not support llama 3.1", + } + }, + }, + { + "type": "Speculative-Decoding", + "enabled": True, + "diy_workflow_overrides": { + "gpu-lmi-trt": { + "enabled": False, + "reason": "LMI v11 does not support Speculative Decoding for TRT", + } + }, + }, + { + "type": "Quantization", + "enabled": False, + "diy_workflow_overrides": { + "gpu-lmi-trt": { + "enabled": False, + "reason": "TRT-LLM 0.11.0 in LMI v11 does not support llama 3.1", + } + }, + }, + ], + "BenchmarkMetrics": {"ml.g6.2xlarge": None}, +} +GATED_DRAFT_MODEL_CONFIG = { + "channel_name": "draft_model", + "provider": {"name": "JumpStart", "classification": "gated"}, + "artifact_version": "v1", + "hosting_eula_key": "fmhMetadata/eula/llama3_2Eula.txt", + "s3_data_source": { + "s3_uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/" + "inference-prepack/v1.0.0/", + "compression_type": "None", + "s3_data_type": "S3Prefix", + }, +} +NON_GATED_DRAFT_MODEL_CONFIG = { + "channel_name": "draft_model", + "s3_data_source": { + "compression_type": "None", + "s3_data_type": "S3Prefix", + "s3_uri": "s3://sagemaker-sd-models-beta-us-west-2/" + "sagemaker-speculative-decoding-llama3-small-v3/", + }, +} +CAMEL_CASE_ADDTL_DRAFT_MODEL_DATA_SOURCES = [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": "meta-textgeneration/meta-textgeneration-llama-3-2-1b/artifacts/" + "inference-prepack/v1.0.0/", + "CompressionType": "None", + "S3DataType": "S3Prefix", + }, + } +] diff --git a/tests/unit/sagemaker/serve/detector/test_dependency_manager.py b/tests/unit/sagemaker/serve/detector/test_dependency_manager.py index 491968dd25..52e9822e57 100644 --- a/tests/unit/sagemaker/serve/detector/test_dependency_manager.py +++ b/tests/unit/sagemaker/serve/detector/test_dependency_manager.py @@ -21,7 +21,7 @@ DEPENDENCY_LIST = [ "requests==2.26.0", - "numpy>=1.20.0", + "numpy==1.26.4", "pandas<=1.3.3", "matplotlib<3.5.0", "scikit-learn>0.24.1", @@ -34,7 +34,7 @@ EXPECTED_DEPENDENCY_MAP = { "requests": "==2.26.0", - "numpy": ">=1.20.0", + "numpy": "==1.26.4", "pandas": "<=1.3.3", "matplotlib": "<3.5.0", "scikit-learn": ">0.24.1", diff --git a/tests/unit/sagemaker/serve/detector/test_pickle_dependencies.py b/tests/unit/sagemaker/serve/detector/test_pickle_dependencies.py index 34cab8a526..ced9555fc5 100644 --- a/tests/unit/sagemaker/serve/detector/test_pickle_dependencies.py +++ b/tests/unit/sagemaker/serve/detector/test_pickle_dependencies.py @@ -93,13 +93,14 @@ def create_mock_modules(name, doc, file): # happy case def test_generate_requirements_exact_match(monkeypatch): - 
with patch("cloudpickle.load"), patch("tqdm.tqdm"), patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.run" - ) as subprocess_run, patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.Popen" - ) as subprocess_popen, patch( - "builtins.open" - ) as mocked_open, monkeypatch.context() as m: + with ( + patch("cloudpickle.load"), + patch("tqdm.tqdm"), + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.run") as subprocess_run, + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.Popen") as subprocess_popen, + patch("builtins.open") as mocked_open, + monkeypatch.context() as m, + ): mock_run_stdout = MagicMock() mock_run_stdout.stdout = json.dumps(INSTALLED_PKG_JSON).encode("utf-8") subprocess_run.return_value = mock_run_stdout @@ -147,13 +148,14 @@ def test_generate_requirements_exact_match(monkeypatch): def test_generate_requirements_txt_pruning_unused_packages(monkeypatch): - with patch("cloudpickle.load"), patch("tqdm.tqdm"), patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.run" - ) as subprocess_run, patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.Popen" - ) as subprocess_popen, patch( - "builtins.open" - ) as mocked_open, monkeypatch.context() as m: + with ( + patch("cloudpickle.load"), + patch("tqdm.tqdm"), + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.run") as subprocess_run, + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.Popen") as subprocess_popen, + patch("builtins.open") as mocked_open, + monkeypatch.context() as m, + ): mock_run_stdout = MagicMock() mock_run_stdout.stdout = json.dumps(INSTALLED_PKG_JSON_UNUSED).encode("utf-8") subprocess_run.return_value = mock_run_stdout @@ -201,13 +203,14 @@ def test_generate_requirements_txt_pruning_unused_packages(monkeypatch): def test_generate_requirements_txt_no_currently_used_packages(monkeypatch): - with patch("cloudpickle.load"), patch("tqdm.tqdm"), patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.run" - ) as subprocess_run, patch( - "sagemaker.serve.detector.pickle_dependencies.subprocess.Popen" - ) as subprocess_popen, patch( - "builtins.open" - ) as mocked_open, monkeypatch.context() as m: + with ( + patch("cloudpickle.load"), + patch("tqdm.tqdm"), + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.run") as subprocess_run, + patch("sagemaker.serve.detector.pickle_dependencies.subprocess.Popen") as subprocess_popen, + patch("builtins.open") as mocked_open, + monkeypatch.context() as m, + ): mock_run_stdout = MagicMock() mock_run_stdout.stdout = json.dumps([]).encode("utf-8") subprocess_run.return_value = mock_run_stdout diff --git a/tests/unit/sagemaker/serve/mode/test_in_process_mode.py b/tests/unit/sagemaker/serve/mode/test_in_process_mode.py new file mode 100644 index 0000000000..29d625dbbc --- /dev/null +++ b/tests/unit/sagemaker/serve/mode/test_in_process_mode.py @@ -0,0 +1,214 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import unittest +from unittest.mock import patch, Mock + +from sagemaker.serve.mode.in_process_mode import InProcessMode +from sagemaker.serve import SchemaBuilder +from sagemaker.serve.utils.exceptions import InProcessDeepPingException + + +mock_prompt = "Hello, I'm a language model," +mock_response = "Hello, I'm a language model, and I'm here to help you with your English." +mock_sample_input = {"inputs": mock_prompt, "parameters": {}} +mock_sample_output = [{"generated_text": mock_response}] +mock_model = "gpt2" + + +class TestInProcessMode(unittest.TestCase): + + @patch("sagemaker.serve.mode.in_process_mode.Path") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + @patch("sagemaker.session.Session") + def test_load_happy_transformers(self, mock_session, mock_inference_spec, mock_path): + mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True + mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True + + mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" + + mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) + in_process_mode = InProcessMode( + inference_spec=mock_inference_spec, + model=mock_model, + schema_builder=mock_schema_builder, + session=mock_session, + model_path="model_path", + env_vars={"key": "val"}, + ) + + res = in_process_mode.load(model_path="/tmp/model-builder/code/") + + self.assertEqual(res, "Dummy load") + self.assertEqual(in_process_mode.inference_spec, mock_inference_spec) + self.assertEqual(in_process_mode.schema_builder, mock_schema_builder) + self.assertEqual(in_process_mode.model_path, "model_path") + self.assertEqual(in_process_mode.env_vars, {"key": "val"}) + + @patch("sagemaker.serve.mode.in_process_mode.Path") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + @patch("sagemaker.session.Session") + def test_load_happy_djl_serving(self, mock_session, mock_inference_spec, mock_path): + mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True + mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True + + mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" + + mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) + in_process_mode = InProcessMode( + inference_spec=mock_inference_spec, + model=mock_model, + schema_builder=mock_schema_builder, + session=mock_session, + model_path="model_path", + env_vars={"key": "val"}, + ) + + res = in_process_mode.load(model_path="/tmp/model-builder/code/") + + self.assertEqual(res, "Dummy load") + self.assertEqual(in_process_mode.inference_spec, mock_inference_spec) + self.assertEqual(in_process_mode.schema_builder, mock_schema_builder) + self.assertEqual(in_process_mode.model_path, "model_path") + self.assertEqual(in_process_mode.env_vars, {"key": "val"}) + + @patch("sagemaker.serve.mode.in_process_mode.Path") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + @patch("sagemaker.session.Session") + def test_load_ex(self, mock_session, mock_inference_spec, mock_path): + mock_path.return_value.exists.side_effect = lambda *args, **kwargs: False + mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: True + + mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" + + mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) + in_process_mode = InProcessMode( + inference_spec=mock_inference_spec, + model=mock_model, + 
schema_builder=mock_schema_builder, + session=mock_session, + model_path="model_path", + ) + + self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/") + + mock_path.return_value.exists.side_effect = lambda *args, **kwargs: True + mock_path.return_value.is_dir.side_effect = lambda *args, **kwargs: False + + mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" + mock_schema_builder = SchemaBuilder(mock_sample_input, mock_sample_output) + in_process_mode = InProcessMode( + inference_spec=mock_inference_spec, + model=mock_model, + schema_builder=mock_schema_builder, + session=mock_session, + model_path="model_path", + ) + + self.assertRaises(ValueError, in_process_mode.load, "/tmp/model-builder/code/") + + @patch("sagemaker.serve.mode.in_process_mode.logger") + @patch("sagemaker.base_predictor.PredictorBase") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + @patch("sagemaker.session.Session") + def test_create_server_happy( + self, mock_session, mock_inference_spec, mock_predictor, mock_logger + ): + mock_start_serving = Mock() + mock_start_serving.side_effect = lambda *args, **kwargs: ( + True, + None, + ) + + mock_response = "Fake response" + mock_multi_model_server_deep_ping = Mock() + mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: ( + True, + mock_response, + ) + + in_process_mode = InProcessMode( + inference_spec=mock_inference_spec, + model=mock_model, + schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), + session=mock_session, + model_path="model_path", + ) + + in_process_mode._deep_ping = mock_multi_model_server_deep_ping + in_process_mode._start_serving = mock_start_serving + + in_process_mode.create_server(predictor=mock_predictor) + + mock_logger.info.assert_called_once_with("Waiting for fastapi server to start up...") + mock_logger.debug.assert_called_once_with( + "Ping health check has passed. 
Returned %s", str(mock_response) + ) + + @patch("sagemaker.base_predictor.PredictorBase") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + @patch("sagemaker.session.Session") + def test_create_server_ex( + self, + mock_session, + mock_inference_spec, + mock_predictor, + ): + mock_start_serving = Mock() + mock_start_serving.side_effect = lambda *args, **kwargs: ( + True, + None, + ) + + mock_multi_model_server_deep_ping = Mock() + mock_multi_model_server_deep_ping.side_effect = lambda *args, **kwargs: ( + False, + None, + ) + + in_process_mode = InProcessMode( + inference_spec=mock_inference_spec, + model=mock_model, + schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), + session=mock_session, + model_path="model_path", + ) + + in_process_mode._deep_ping = mock_multi_model_server_deep_ping + in_process_mode._start_serving = mock_start_serving + + self.assertRaises(InProcessDeepPingException, in_process_mode.create_server, mock_predictor) + + @patch( + "sagemaker.serve.model_server.in_process_model_server.in_process_server.InProcessServing._stop_serving" + ) + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + @patch("sagemaker.session.Session") + def test_destroy_server( + self, + mock_session, + mock_inference_spec, + mock_stop_serving, + ): + in_process_mode = InProcessMode( + inference_spec=mock_inference_spec, + model=mock_model, + schema_builder=SchemaBuilder(mock_sample_input, mock_sample_output), + session=mock_session, + model_path="model_path", + ) + + in_process_mode.destroy_server() + + mock_stop_serving.assert_called() diff --git a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py index 154b6b7d95..819800ba46 100644 --- a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py +++ b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import os +from pathlib import Path from unittest.mock import patch, MagicMock, mock_open import pytest @@ -21,6 +22,7 @@ from sagemaker.serve import ModelServer from sagemaker.serve.model_format.mlflow.constants import ( MLFLOW_PYFUNC, + TENSORFLOW_SAVED_MODEL_NAME, ) from sagemaker.serve.model_format.mlflow.utils import ( _get_default_model_server_for_mlflow, @@ -30,11 +32,12 @@ _get_framework_version_from_requirements, _get_deployment_flavor, _get_python_version_from_parsed_mlflow_model_file, - _mlflow_input_is_local_path, _download_s3_artifacts, _select_container_for_mlflow_model, _validate_input_for_mlflow, _copy_directory_contents, + _move_contents, + _get_saved_model_path_for_tensorflow_and_keras_flavor, ) @@ -193,17 +196,6 @@ def test_get_python_version_from_parsed_mlflow_model_file(): _get_python_version_from_parsed_mlflow_model_file({}) -@patch("os.path.exists") -def test_mlflow_input_is_local_path(mock_path_exists): - valid_path = "/path/to/mlflow_model" - mock_path_exists.side_effect = lambda path: path == valid_path - - assert not _mlflow_input_is_local_path("s3://my_bucket/path/to/model") - assert not _mlflow_input_is_local_path("runs:/run-id/run/relative/path/to/model") - assert not _mlflow_input_is_local_path("/invalid/path") - assert _mlflow_input_is_local_path(valid_path) - - def test_download_s3_artifacts(): pass @@ -414,11 +406,71 @@ def test_select_container_for_mlflow_model_no_dlc_detected( ) +@patch("sagemaker.image_uris.retrieve") +@patch("sagemaker.serve.model_format.mlflow.utils._cast_to_compatible_version") 
+@patch("sagemaker.serve.model_format.mlflow.utils._get_framework_version_from_requirements") +@patch( + "sagemaker.serve.model_format.mlflow.utils._get_python_version_from_parsed_mlflow_model_file" +) +@patch("sagemaker.serve.model_format.mlflow.utils._get_all_flavor_metadata") +@patch("sagemaker.serve.model_format.mlflow.utils._generate_mlflow_artifact_path") +def test_select_container_for_mlflow_model_no_framework_version_detected( + mock_generate_mlflow_artifact_path, + mock_get_all_flavor_metadata, + mock_get_python_version_from_parsed_mlflow_model_file, + mock_get_framework_version_from_requirements, + mock_cast_to_compatible_version, + mock_image_uris_retrieve, +): + mlflow_model_src_path = "/path/to/mlflow_model" + deployment_flavor = "pytorch" + region = "us-west-2" + instance_type = "ml.m5.xlarge" + + mock_requirements_path = "/path/to/requirements.txt" + mock_metadata_path = "/path/to/mlmodel" + mock_flavor_metadata = {"pytorch": {"some_key": "some_value"}} + mock_python_version = "3.8.6" + + mock_generate_mlflow_artifact_path.side_effect = lambda path, artifact: ( + mock_requirements_path if artifact == "requirements.txt" else mock_metadata_path + ) + mock_get_all_flavor_metadata.return_value = mock_flavor_metadata + mock_get_python_version_from_parsed_mlflow_model_file.return_value = mock_python_version + mock_get_framework_version_from_requirements.return_value = None + + with pytest.raises( + ValueError, + match="Unable to auto detect framework version. Please provide framework " + "pytorch as part of the requirements.txt file for deployment flavor " + "pytorch", + ): + _select_container_for_mlflow_model( + mlflow_model_src_path, deployment_flavor, region, instance_type + ) + + mock_generate_mlflow_artifact_path.assert_any_call( + mlflow_model_src_path, "requirements.txt" + ) + mock_generate_mlflow_artifact_path.assert_any_call(mlflow_model_src_path, "MLmodel") + mock_get_all_flavor_metadata.assert_called_once_with(mock_metadata_path) + mock_get_framework_version_from_requirements.assert_called_once_with( + deployment_flavor, mock_requirements_path + ) + mock_cast_to_compatible_version.assert_not_called() + mock_image_uris_retrieve.assert_not_called() + + def test_validate_input_for_mlflow(): - _validate_input_for_mlflow(ModelServer.TORCHSERVE) + _validate_input_for_mlflow(ModelServer.TORCHSERVE, "pytorch") with pytest.raises(ValueError): - _validate_input_for_mlflow(ModelServer.DJL_SERVING) + _validate_input_for_mlflow(ModelServer.DJL_SERVING, "pytorch") + + +def test_validate_input_for_mlflow_non_supported_flavor_with_tf_serving(): + with pytest.raises(ValueError): + _validate_input_for_mlflow(ModelServer.TENSORFLOW_SERVING, "pytorch") @patch("sagemaker.serve.model_format.mlflow.utils.shutil.copy2") @@ -472,3 +524,68 @@ def test_copy_directory_contents_handles_same_src_dst( mock_os_walk.assert_not_called() mock_os_makedirs.assert_not_called() mock_shutil_copy2.assert_not_called() + + +@patch("os.path.abspath") +@patch("os.walk") +def test_get_saved_model_path_found(mock_os_walk, mock_os_abspath): + mock_os_walk.return_value = [ + ("/root/folder1", ("subfolder",), ()), + ("/root/folder1/subfolder", (), (TENSORFLOW_SAVED_MODEL_NAME,)), + ] + expected_path = "/root/folder1/subfolder" + mock_os_abspath.return_value = expected_path + + # Call the function + result = _get_saved_model_path_for_tensorflow_and_keras_flavor("/root/folder1") + + # Assertions + mock_os_walk.assert_called_once_with("/root/folder1") + mock_os_abspath.assert_called_once_with("/root/folder1/subfolder") + 
assert result == expected_path + + +@patch("os.path.abspath") +@patch("os.walk") +def test_get_saved_model_path_not_found(mock_os_walk, mock_os_abspath): + mock_os_walk.return_value = [ + ("/root/folder2", ("subfolder",), ()), + ("/root/folder2/subfolder", (), ("not_saved_model.pb",)), + ] + + result = _get_saved_model_path_for_tensorflow_and_keras_flavor("/root/folder2") + + mock_os_walk.assert_called_once_with("/root/folder2") + mock_os_abspath.assert_not_called() + assert result is None + + +@patch("sagemaker.serve.model_format.mlflow.utils.shutil.move") +@patch("sagemaker.serve.model_format.mlflow.utils.Path.iterdir") +@patch("sagemaker.serve.model_format.mlflow.utils.Path.mkdir") +def test_move_contents_handles_same_src_dst(mock_mkdir, mock_iterdir, mock_shutil_move): + src_dir = "/fake/source/dir" + dest_dir = "/fake/source/./dir" + + mock_iterdir.return_value = [] + + _move_contents(src_dir, dest_dir) + + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + mock_shutil_move.assert_not_called() + + +@patch("sagemaker.serve.model_format.mlflow.utils.shutil.move") +@patch("sagemaker.serve.model_format.mlflow.utils.Path.iterdir") +@patch("sagemaker.serve.model_format.mlflow.utils.Path.mkdir") +def test_move_contents_with_actual_files(mock_mkdir, mock_iterdir, mock_shutil_move): + src_dir = Path("/fake/source/dir") + dest_dir = Path("/fake/destination/dir") + + file_path = src_dir / "testfile.txt" + mock_iterdir.return_value = [file_path] + + _move_contents(src_dir, dest_dir) + + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + mock_shutil_move.assert_called_once_with(str(file_path), str(dest_dir / "testfile.txt")) diff --git a/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py b/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py index caa8884186..aa99e1971c 100644 --- a/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py +++ b/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py @@ -13,12 +13,11 @@ from __future__ import absolute_import from unittest import TestCase -from unittest.mock import Mock, PropertyMock, patch, mock_open, call +from unittest.mock import Mock, PropertyMock, patch, mock_open from sagemaker.serve.model_server.djl_serving.prepare import ( _copy_jumpstart_artifacts, _create_dir_structure, - _move_to_code_dir, _extract_js_resource, ) from tests.unit.sagemaker.serve.model_server.constants import ( @@ -31,7 +30,7 @@ MOCK_INVALID_MODEL_DATA_DICT, ) -MOCK_DJL_JUMPSTART_GLOBED_RESOURCES = ["./inference.py", "./serving.properties", "./config.json"] +MOCK_DJL_JUMPSTART_GLOBED_RESOURCES = ["./config.json"] class DjlPrepareTests(TestCase): @@ -53,8 +52,8 @@ def test_create_dir_structure_from_new(self, mock_path, mock_disk_usage, mock_di mock_disk_space.assert_called_once_with(mock_model_path) mock_disk_usage.assert_called_once() - self.assertEquals(ret_model_path, mock_model_path) - self.assertEquals(ret_code_dir, mock_code_dir) + self.assertEqual(ret_model_path, mock_model_path) + self.assertEqual(ret_code_dir, mock_code_dir) @patch("sagemaker.serve.model_server.djl_serving.prepare.Path") def test_create_dir_structure_invalid_path(self, mock_path): @@ -66,117 +65,68 @@ def test_create_dir_structure_invalid_path(self, mock_path): with self.assertRaises(ValueError) as context: _create_dir_structure(mock_model_path) - self.assertEquals("model_dir is not a valid directory", str(context.exception)) + self.assertEqual("model_dir is not a valid directory", 
str(context.exception)) @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") - @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") - @patch( - "sagemaker.serve.model_server.djl_serving.prepare._read_existing_serving_properties", - return_value={}, - ) - @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") @patch("builtins.open", new_callable=mock_open, read_data="data") @patch("json.load", return_value={}) def test_prepare_djl_js_resources_for_jumpstart_uncompressed_str( self, mock_load, mock_open, - mock_move_to_code_dir, - mock_existing_props, - mock_tmpdir, mock_s3_downloader, ): mock_code_dir = Mock() - mock_config_json_file = Mock() - mock_config_json_file.is_file.return_value = True - mock_code_dir.joinpath.return_value = mock_config_json_file - mock_s3_downloader_obj = Mock() mock_s3_downloader.return_value = mock_s3_downloader_obj - mock_tmpdir_obj = Mock() - mock_js_dir = Mock() - mock_js_dir.return_value = MOCK_TMP_DIR - type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=mock_js_dir) - type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) - mock_tmpdir.return_value = mock_tmpdir_obj - - existing_properties, hf_model_config, success = _copy_jumpstart_artifacts( + _copy_jumpstart_artifacts( MOCK_UNCOMPRESSED_MODEL_DATA_STR, MOCK_JUMPSTART_ID, mock_code_dir ) mock_s3_downloader_obj.download.assert_called_once_with( - MOCK_UNCOMPRESSED_MODEL_DATA_STR, MOCK_TMP_DIR + MOCK_UNCOMPRESSED_MODEL_DATA_STR, mock_code_dir ) - mock_move_to_code_dir.assert_called_once_with(MOCK_TMP_DIR, mock_code_dir) - mock_code_dir.joinpath.assert_called_once_with("config.json") - self.assertEqual(existing_properties, {}) - self.assertEqual(hf_model_config, {}) - self.assertEqual(success, True) @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") - @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") - @patch( - "sagemaker.serve.model_server.djl_serving.prepare._read_existing_serving_properties", - return_value={}, - ) - @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") @patch("builtins.open", new_callable=mock_open, read_data="data") @patch("json.load", return_value={}) def test_prepare_djl_js_resources_for_jumpstart_uncompressed_dict( self, mock_load, mock_open, - mock_move_to_code_dir, - mock_existing_props, - mock_tmpdir, mock_s3_downloader, ): mock_code_dir = Mock() - mock_config_json_file = Mock() - mock_config_json_file.is_file.return_value = True - mock_code_dir.joinpath.return_value = mock_config_json_file - mock_s3_downloader_obj = Mock() mock_s3_downloader.return_value = mock_s3_downloader_obj - mock_tmpdir_obj = Mock() - mock_js_dir = Mock() - mock_js_dir.return_value = MOCK_TMP_DIR - type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=mock_js_dir) - type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) - mock_tmpdir.return_value = mock_tmpdir_obj - - existing_properties, hf_model_config, success = _copy_jumpstart_artifacts( + _copy_jumpstart_artifacts( MOCK_UNCOMPRESSED_MODEL_DATA_DICT, MOCK_JUMPSTART_ID, mock_code_dir ) mock_s3_downloader_obj.download.assert_called_once_with( - MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT, MOCK_TMP_DIR + MOCK_UNCOMPRESSED_MODEL_DATA_STR_FOR_DICT, mock_code_dir ) - mock_move_to_code_dir.assert_called_once_with(MOCK_TMP_DIR, mock_code_dir) - mock_code_dir.joinpath.assert_called_once_with("config.json") - self.assertEqual(existing_properties, {}) - self.assertEqual(hf_model_config, {}) - 
self.assertEqual(success, True) - @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") - @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") + @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") + @patch("builtins.open", new_callable=mock_open, read_data="data") + @patch("json.load", return_value={}) def test_prepare_djl_js_resources_for_jumpstart_invalid_model_data( - self, mock_move_to_code_dir, mock_tmpdir + self, + mock_load, + mock_open, + mock_s3_downloader, ): mock_code_dir = Mock() - mock_tmpdir_obj = Mock() - type(mock_tmpdir_obj).__enter__ = PropertyMock(return_value=Mock()) - type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) - mock_tmpdir.return_value = mock_tmpdir_obj + mock_s3_downloader_obj = Mock() + mock_s3_downloader.return_value = mock_s3_downloader_obj with self.assertRaises(ValueError) as context: _copy_jumpstart_artifacts( MOCK_INVALID_MODEL_DATA_DICT, MOCK_JUMPSTART_ID, mock_code_dir ) - assert not mock_move_to_code_dir.called self.assertTrue( "JumpStart model data compression format is unsupported" in str(context.exception) ) @@ -184,27 +134,17 @@ def test_prepare_djl_js_resources_for_jumpstart_invalid_model_data( @patch("sagemaker.serve.model_server.djl_serving.prepare.S3Downloader") @patch("sagemaker.serve.model_server.djl_serving.prepare._extract_js_resource") @patch("sagemaker.serve.model_server.djl_serving.prepare._tmpdir") - @patch( - "sagemaker.serve.model_server.djl_serving.prepare._read_existing_serving_properties", - return_value={}, - ) - @patch("sagemaker.serve.model_server.djl_serving.prepare._move_to_code_dir") @patch("builtins.open", new_callable=mock_open, read_data="data") @patch("json.load", return_value={}) def test_prepare_djl_js_resources_for_jumpstart_compressed_str( self, mock_load, mock_open, - mock_move_to_code_dir, - mock_existing_props, mock_tmpdir, mock_extract_js_resource, mock_s3_downloader, ): mock_code_dir = Mock() - mock_config_json_file = Mock() - mock_config_json_file.is_file.return_value = True - mock_code_dir.joinpath.return_value = mock_config_json_file mock_s3_downloader_obj = Mock() mock_s3_downloader.return_value = mock_s3_downloader_obj @@ -216,41 +156,14 @@ def test_prepare_djl_js_resources_for_jumpstart_compressed_str( type(mock_tmpdir_obj).__exit__ = PropertyMock(return_value=Mock()) mock_tmpdir.return_value = mock_tmpdir_obj - existing_properties, hf_model_config, success = _copy_jumpstart_artifacts( - MOCK_COMPRESSED_MODEL_DATA_STR, MOCK_JUMPSTART_ID, mock_code_dir - ) + _copy_jumpstart_artifacts(MOCK_COMPRESSED_MODEL_DATA_STR, MOCK_JUMPSTART_ID, mock_code_dir) mock_s3_downloader_obj.download.assert_called_once_with( MOCK_COMPRESSED_MODEL_DATA_STR, MOCK_TMP_DIR ) - mock_extract_js_resource.assert_called_with(MOCK_TMP_DIR, MOCK_JUMPSTART_ID) - mock_move_to_code_dir.assert_called_once_with(MOCK_TMP_DIR, mock_code_dir) - mock_code_dir.joinpath.assert_called_once_with("config.json") - self.assertEqual(existing_properties, {}) - self.assertEqual(hf_model_config, {}) - self.assertEqual(success, True) - - @patch("sagemaker.serve.model_server.djl_serving.prepare.Path") - @patch("sagemaker.serve.model_server.djl_serving.prepare.shutil") - def test_move_to_code_dir_success(self, mock_shutil, mock_path): - mock_path_obj = Mock() - mock_js_model_resources = Mock() - mock_js_model_resources.glob.return_value = MOCK_DJL_JUMPSTART_GLOBED_RESOURCES - mock_path_obj.joinpath.return_value = mock_js_model_resources - mock_path.return_value = mock_path_obj - - 
mock_js_model_dir = "" - mock_code_dir = Mock() - _move_to_code_dir(mock_js_model_dir, mock_code_dir) - - mock_path_obj.joinpath.assert_called_once_with("model") - - expected_moves = [ - call("./inference.py", mock_code_dir), - call("./serving.properties", mock_code_dir), - call("./config.json", mock_code_dir), - ] - mock_shutil.move.assert_has_calls(expected_moves) + mock_extract_js_resource.assert_called_once_with( + MOCK_TMP_DIR, mock_code_dir, MOCK_JUMPSTART_ID + ) @patch("sagemaker.serve.model_server.djl_serving.prepare.Path") @patch("sagemaker.serve.model_server.djl_serving.prepare.tarfile") @@ -268,8 +181,9 @@ def test_extract_js_resources_success(self, mock_tarfile, mock_path): mock_tarfile.open.return_value = mock_tar_obj js_model_dir = "" - _extract_js_resource(js_model_dir, MOCK_JUMPSTART_ID) + code_dir = Mock() + _extract_js_resource(js_model_dir, code_dir, MOCK_JUMPSTART_ID) mock_path.assert_called_once_with(js_model_dir) mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz") - mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir, filter="data") + mock_resource_obj.extractall.assert_called_once_with(path=code_dir, filter="data") diff --git a/tests/unit/sagemaker/serve/model_server/in_process_model_server/test_app.py b/tests/unit/sagemaker/serve/model_server/in_process_model_server/test_app.py new file mode 100644 index 0000000000..65ba80c370 --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/in_process_model_server/test_app.py @@ -0,0 +1,91 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import unittest +import pytest + +from unittest.mock import patch, Mock, AsyncMock +from sagemaker.serve.model_server.in_process_model_server.app import InProcessServer +from tests.integ.sagemaker.serve.constants import ( + PYTHON_VERSION_IS_NOT_310, +) + +mock_model_id = "mock_model_id" + + +class TestAppInProcessServer(unittest.IsolatedAsyncioTestCase): + @pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these tests is to test the serving components of our feature", + ) + @patch("sagemaker.serve.model_server.in_process_model_server.app.threading") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + def test_in_process_server_init(self, mock_inference_spec, mock_threading): + mock_generator = Mock() + mock_generator.side_effect = None + + in_process_server = InProcessServer(inference_spec=mock_inference_spec) + in_process_server._generator = mock_generator + + @pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these tests is to test the serving components of our feature", + ) + @patch("sagemaker.serve.model_server.in_process_model_server.app.logger") + @patch("sagemaker.serve.model_server.in_process_model_server.app.threading") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + def test_start_server(self, mock_inference_spec, mock_threading, mock_logger): + mock_generator = Mock() + mock_generator.side_effect = None + mock_thread = Mock() + mock_threading.Thread.return_value = mock_thread + + in_process_server = InProcessServer(inference_spec=mock_inference_spec) + in_process_server._generator = mock_generator + + in_process_server.start_server() + + mock_logger.info.assert_called() + mock_thread.start.assert_called() + + @pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these tests is to test the serving components of our feature", + ) + @patch("sagemaker.serve.model_server.in_process_model_server.app.asyncio") + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + def test_start_run_async_in_thread(self, mock_inference_spec, mock_asyncio): + mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" + + mock_loop = Mock() + mock_asyncio.new_event_loop.side_effect = lambda: mock_loop + + in_process_server = InProcessServer(inference_spec=mock_inference_spec) + in_process_server._start_run_async_in_thread() + + mock_asyncio.set_event_loop.assert_called_once_with(mock_loop) + mock_loop.run_until_complete.assert_called() + + @patch("sagemaker.serve.spec.inference_spec.InferenceSpec") + async def test_serve(self, mock_inference_spec): + mock_inference_spec.load.side_effect = lambda *args, **kwargs: "Dummy load" + + mock_server = AsyncMock() + + in_process_server = InProcessServer(inference_spec=mock_inference_spec) + in_process_server.server = mock_server + + await in_process_server._serve() + + mock_server.serve.assert_called() diff --git a/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py b/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py index 895ed3907f..567a72182a 100644 --- a/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py +++ b/tests/unit/sagemaker/serve/model_server/multi_model_server/test_multi_model_server_prepare.py @@ -12,13 +12,67 @@ # language governing permissions and limitations under the License.
from __future__ import absolute_import +from pathlib import PosixPath +import platform from unittest import TestCase from unittest.mock import Mock, patch + +import numpy as np + from sagemaker.serve.model_server.multi_model_server.prepare import _create_dir_structure +from sagemaker.serve.model_server.multi_model_server.server import ( + LocalMultiModelServer, +) + +CPU_PYTORCH_IMAGE = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04" +) +MODEL_PATH = "model_path" +MODEL_REPO = f"{MODEL_PATH}/1" +ENV_VAR = {"KEY": "VALUE"} +PAYLOAD = np.random.rand(3, 4).astype(dtype=np.float32) +DTYPE = "TYPE_FP32" +SECRET_KEY = "secret_key" +INFER_RESPONSE = {"outputs": [{"name": "output_name"}]} + + class MultiModelServerPrepareTests(TestCase): + def test_start_invoke_destroy_local_multi_model_server(self): + mock_container = Mock() + mock_docker_client = Mock() + mock_docker_client.containers.run.return_value = mock_container + + local_multi_model_server = LocalMultiModelServer() + mock_schema_builder = Mock() + mock_schema_builder.input_serializer.serialize.return_value = PAYLOAD + local_multi_model_server.schema_builder = mock_schema_builder + + local_multi_model_server._start_serving( + client=mock_docker_client, + model_path=MODEL_PATH, + secret_key=SECRET_KEY, + env_vars=ENV_VAR, + image=CPU_PYTORCH_IMAGE, + ) + + mock_docker_client.containers.run.assert_called_once_with( + CPU_PYTORCH_IMAGE, + "serve", + network_mode="host", + detach=True, + auto_remove=True, + volumes={PosixPath("model_path/code"): {"bind": "/opt/ml/model/", "mode": "rw"}}, + environment={ + "KEY": "VALUE", + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_SERVE_SECRET_KEY": "secret_key", + "LOCAL_PYTHON": platform.python_version(), + }, + ) + @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_disk_space") @patch("sagemaker.serve.model_server.multi_model_server.prepare._check_docker_disk_usage") @patch("sagemaker.serve.model_server.multi_model_server.prepare.Path") @@ -37,8 +91,8 @@ def test_create_dir_structure_from_new(self, mock_path, mock_disk_usage, mock_di mock_disk_space.assert_called_once_with(mock_model_path) mock_disk_usage.assert_called_once() - self.assertEquals(ret_model_path, mock_model_path) - self.assertEquals(ret_code_dir, mock_code_dir) + self.assertEqual(ret_model_path, mock_model_path) + self.assertEqual(ret_code_dir, mock_code_dir) @patch("sagemaker.serve.model_server.multi_model_server.prepare.Path") def test_create_dir_structure_invalid_path(self, mock_path): @@ -50,4 +104,4 @@ def test_create_dir_structure_invalid_path(self, mock_path): with self.assertRaises(ValueError) as context: _create_dir_structure(mock_model_path) - self.assertEquals("model_dir is not a valid directory", str(context.exception)) + self.assertEqual("model_dir is not a valid directory", str(context.exception)) diff --git a/tests/unit/sagemaker/serve/model_server/tei/__init__.py b/tests/unit/sagemaker/serve/model_server/tei/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/serve/model_server/tei/test_server.py b/tests/unit/sagemaker/serve/model_server/tei/test_server.py new file mode 100644 index 0000000000..47399c1fad --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/tei/test_server.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License").
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from pathlib import PosixPath +from unittest import TestCase +from unittest.mock import Mock, patch + +from docker.types import DeviceRequest +from sagemaker.serve.model_server.tei.server import LocalTeiServing, SageMakerTeiServing +from sagemaker.serve.utils.exceptions import LocalModelInvocationException + +TEI_IMAGE = ( + "246618743249.dkr.ecr.us-west-2.amazonaws.com/tei:2.0.1-tei1.2.3-gpu-py310-cu122-ubuntu22.04" +) +MODEL_PATH = "model_path" +ENV_VAR = {"KEY": "VALUE"} +PAYLOAD = { + "inputs": { + "sourceSentence": "How cute your dog is!", + "sentences": ["The mitochondria is the powerhouse of the cell.", "Your dog is so cute."], + } +} +S3_URI = "s3://mock_model_data_uri" +SECRET_KEY = "secret_key" +INFER_RESPONSE = [] + + +class TeiServerTests(TestCase): + @patch("sagemaker.serve.model_server.tei.server.requests") + def test_start_invoke_destroy_local_tei_server(self, mock_requests): + mock_container = Mock() + mock_docker_client = Mock() + mock_docker_client.containers.run.return_value = mock_container + + local_tei_server = LocalTeiServing() + mock_schema_builder = Mock() + mock_schema_builder.input_serializer.serialize.return_value = PAYLOAD + local_tei_server.schema_builder = mock_schema_builder + + local_tei_server._start_tei_serving( + client=mock_docker_client, + model_path=MODEL_PATH, + secret_key=SECRET_KEY, + image=TEI_IMAGE, + env_vars=ENV_VAR, + ) + + mock_docker_client.containers.run.assert_called_once_with( + TEI_IMAGE, + shm_size="2G", + device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])], + network_mode="host", + detach=True, + auto_remove=True, + volumes={PosixPath("model_path/code"): {"bind": "/opt/ml/model/", "mode": "rw"}}, + environment={ + "HF_HOME": "/opt/ml/model/", + "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", + "KEY": "VALUE", + "SAGEMAKER_SERVE_SECRET_KEY": "secret_key", + }, + ) + + mock_response = Mock() + mock_requests.post.side_effect = lambda *args, **kwargs: mock_response + mock_response.content = INFER_RESPONSE + + res = local_tei_server._invoke_tei_serving( + request=PAYLOAD, content_type="application/json", accept="application/json" + ) + + self.assertEqual(res, INFER_RESPONSE) + + def test_tei_deep_ping(self): + mock_predictor = Mock() + mock_response = Mock() + mock_schema_builder = Mock() + + mock_predictor.predict.side_effect = lambda *args, **kwargs: mock_response + mock_schema_builder.sample_input = PAYLOAD + + local_tei_server = LocalTeiServing() + local_tei_server.schema_builder = mock_schema_builder + res = local_tei_server._tei_deep_ping(mock_predictor) + + self.assertEqual(res, (True, mock_response)) + + def test_tei_deep_ping_invoke_ex(self): + mock_predictor = Mock() + mock_schema_builder = Mock() + + mock_predictor.predict.side_effect = lambda *args, **kwargs: exec( + 'raise(ValueError("422 Client Error: Unprocessable Entity for url:"))' + ) + mock_schema_builder.sample_input = PAYLOAD + + local_tei_server = LocalTeiServing() + local_tei_server.schema_builder = mock_schema_builder + + self.assertRaises( + LocalModelInvocationException, lambda: 
local_tei_server._tei_deep_ping(mock_predictor) + ) + + def test_tei_deep_ping_ex(self): + mock_predictor = Mock() + + mock_predictor.predict.side_effect = lambda *args, **kwargs: Exception() + + local_tei_server = LocalTeiServing() + res = local_tei_server._tei_deep_ping(mock_predictor) + + self.assertEqual(res, (False, None)) + + @patch("sagemaker.serve.model_server.tei.server.S3Uploader") + def test_upload_artifacts_sagemaker_tei_server(self, mock_uploader): + mock_session = Mock() + mock_uploader.upload.side_effect = ( + lambda *args, **kwargs: "s3://sagemaker-us-west-2-123456789123/tei-2024-05-20-16-05-36-027/code" + ) + + s3_upload_path, env_vars = SageMakerTeiServing()._upload_tei_artifacts( + model_path=MODEL_PATH, + sagemaker_session=mock_session, + s3_model_data_url=S3_URI, + image=TEI_IMAGE, + should_upload_artifacts=True, + ) + + mock_uploader.upload.assert_called_once() + self.assertEqual( + s3_upload_path, + { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": "s3://sagemaker-us-west-2-123456789123/tei-2024-05-20-16-05-36-027/code/", + } + }, + ) + self.assertIsNotNone(env_vars) diff --git a/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_prepare.py b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_prepare.py new file mode 100644 index 0000000000..9915b19649 --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_prepare.py @@ -0,0 +1,116 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest import TestCase +from unittest.mock import Mock, patch, mock_open +import pytest + +from sagemaker.serve.model_server.tensorflow_serving.prepare import prepare_for_tf_serving + +MODEL_PATH = "/path/to/your/model/dir" +SHARED_LIBS = ["/path/to/shared/libs"] +DEPENDENCIES = {"dependencies": "requirements.txt"} +INFERENCE_SPEC = Mock() +IMAGE_URI = "mock_image_uri" +XGB_1P_IMAGE_URI = "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.7-1" +INFERENCE_SPEC.prepare = Mock(return_value=None) + +SECRET_KEY = "secret-key" + +mock_session = Mock() + + +class PrepareForTensorflowServingTests(TestCase): + def setUp(self): + INFERENCE_SPEC.reset_mock() + + @patch("builtins.open", new_callable=mock_open, read_data=b"{}") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents") + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare." 
+ "_get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._MetaData") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.shutil") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.Path") + def test_prepare_happy( + self, + mock_path, + mock_shutil, + mock_capture_dependencies, + mock_generate_secret_key, + mock_compute_hash, + mock_metadata, + mock_get_saved_model_path, + mock_move_contents, + mock_open, + ): + + mock_path_instance = mock_path.return_value + mock_path_instance.exists.return_value = True + mock_path_instance.joinpath.return_value = Mock() + mock_get_saved_model_path.return_value = MODEL_PATH + "/1/" + + mock_generate_secret_key.return_value = SECRET_KEY + + secret_key = prepare_for_tf_serving( + model_path=MODEL_PATH, + shared_libs=SHARED_LIBS, + dependencies=DEPENDENCIES, + ) + + mock_path_instance.mkdir.assert_not_called() + self.assertEqual(secret_key, SECRET_KEY) + + @patch("builtins.open", new_callable=mock_open, read_data=b"{}") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._move_contents") + @patch( + "sagemaker.serve.model_server.tensorflow_serving.prepare." + "_get_saved_model_path_for_tensorflow_and_keras_flavor" + ) + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare._MetaData") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.compute_hash") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.generate_secret_key") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.capture_dependencies") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.shutil") + @patch("sagemaker.serve.model_server.tensorflow_serving.prepare.Path") + def test_prepare_saved_model_not_found( + self, + mock_path, + mock_shutil, + mock_capture_dependencies, + mock_generate_secret_key, + mock_compute_hash, + mock_metadata, + mock_get_saved_model_path, + mock_move_contents, + mock_open, + ): + + mock_path_instance = mock_path.return_value + mock_path_instance.exists.return_value = True + mock_path_instance.joinpath.return_value = Mock() + mock_get_saved_model_path.return_value = None + + with pytest.raises( + ValueError, match="SavedModel is not found for Tensorflow or Keras flavor." + ): + prepare_for_tf_serving( + model_path=MODEL_PATH, + shared_libs=SHARED_LIBS, + dependencies=DEPENDENCIES, + ) diff --git a/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py new file mode 100644 index 0000000000..b9cce13dbb --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/tensorflow_serving/test_tf_server.py @@ -0,0 +1,101 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from pathlib import PosixPath +import platform +from unittest import TestCase +from unittest.mock import Mock, patch, ANY + +import numpy as np + +from sagemaker.serve.model_server.tensorflow_serving.server import ( + LocalTensorflowServing, + SageMakerTensorflowServing, +) + +CPU_TF_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:2.14.1-cpu" +MODEL_PATH = "model_path" +MODEL_REPO = f"{MODEL_PATH}/1" +ENV_VAR = {"KEY": "VALUE"} +_SHM_SIZE = "2G" +PAYLOAD = np.random.rand(3, 4).astype(dtype=np.float32) +S3_URI = "s3://mock_model_data_uri" +DTYPE = "TYPE_FP32" +SECRET_KEY = "secret_key" + +INFER_RESPONSE = {"outputs": [{"name": "output_name"}]} + + +class TensorflowServingServerTests(TestCase): + def test_start_invoke_destroy_local_tensorflow_serving_server(self): + mock_container = Mock() + mock_docker_client = Mock() + mock_docker_client.containers.run.return_value = mock_container + + local_tensorflow_server = LocalTensorflowServing() + mock_schema_builder = Mock() + mock_schema_builder.input_serializer.serialize.return_value = PAYLOAD + local_tensorflow_server.schema_builder = mock_schema_builder + + local_tensorflow_server._start_tensorflow_serving( + client=mock_docker_client, + model_path=MODEL_PATH, + secret_key=SECRET_KEY, + env_vars=ENV_VAR, + image=CPU_TF_IMAGE, + ) + + mock_docker_client.containers.run.assert_called_once_with( + CPU_TF_IMAGE, + "serve", + detach=True, + auto_remove=True, + network_mode="host", + volumes={PosixPath("model_path"): {"bind": "/opt/ml/model", "mode": "rw"}}, + environment={ + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_SERVE_SECRET_KEY": "secret_key", + "LOCAL_PYTHON": platform.python_version(), + "KEY": "VALUE", + }, + ) + + @patch("sagemaker.serve.model_server.tensorflow_serving.server.platform") + @patch("sagemaker.serve.model_server.tensorflow_serving.server.upload") + def test_upload_artifacts_sagemaker_tensorflow_serving_server(self, mock_upload, mock_platform): + mock_session = Mock() + mock_platform.python_version.return_value = "3.8" + mock_upload.side_effect = lambda session, repo, bucket, prefix: ( + S3_URI + if session == mock_session and repo == MODEL_PATH and bucket == "mock_model_data_uri" + else None + ) + + ( + s3_upload_path, + env_vars, + ) = SageMakerTensorflowServing()._upload_tensorflow_serving_artifacts( + model_path=MODEL_PATH, + sagemaker_session=mock_session, + secret_key=SECRET_KEY, + s3_model_data_url=S3_URI, + image=CPU_TF_IMAGE, + should_upload_artifacts=True, + ) + + mock_upload.assert_called_once_with(mock_session, MODEL_PATH, "mock_model_data_uri", ANY) + self.assertEqual(s3_upload_path, S3_URI) + self.assertEqual(env_vars.get("SAGEMAKER_SERVE_SECRET_KEY"), SECRET_KEY) + self.assertEqual(env_vars.get("LOCAL_PYTHON"), "3.8") diff --git a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py index 88d109831d..ed94f10ce9 100644 --- a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py +++ b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py @@ -50,8 +50,8 @@ def test_create_dir_structure_from_new(self, mock_path, mock_disk_usage, mock_di mock_disk_space.assert_called_once_with(mock_model_path) mock_disk_usage.assert_called_once() - self.assertEquals(ret_model_path, mock_model_path) -
self.assertEquals(ret_code_dir, mock_code_dir) + self.assertEqual(ret_model_path, mock_model_path) + self.assertEqual(ret_code_dir, mock_code_dir) @patch("sagemaker.serve.model_server.tgi.prepare.Path") def test_create_dir_structure_invalid_path(self, mock_path): @@ -63,7 +63,7 @@ def test_create_dir_structure_invalid_path(self, mock_path): with self.assertRaises(ValueError) as context: _create_dir_structure(mock_model_path) - self.assertEquals("model_dir is not a valid directory", str(context.exception)) + self.assertEqual("model_dir is not a valid directory", str(context.exception)) @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") @patch("builtins.open", read_data="data") diff --git a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_serving.py b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_serving.py new file mode 100644 index 0000000000..33371fc584 --- /dev/null +++ b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_serving.py @@ -0,0 +1,283 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest import TestCase +from unittest.mock import Mock, patch +from pathlib import Path +from sagemaker.serve.model_server.tgi.server import LocalTgiServing, SageMakerTgiServing + +MOCK_IMAGE = "mock image" +MOCK_MODEL_PATH = "mock model path" +MOCK_SECRET_KEY = "mock secret key" +MOCK_ENV_VARS = {"mock key": "mock value"} +MOCK_SAGEMAKER_SESSION = Mock() +MOCK_S3_MODEL_DATA_URL = "mock s3 path" +MOCK_MODEL_DATA_URL = "mock model data url" + +EXPECTED_MODEL_DIR_BINDING = "/opt/ml/model/" +EXPECTED_SHM_SIZE = "2G" +EXPECTED_UPDATED_ENV_VARS = { + "HF_HOME": "/opt/ml/model/", + "HUGGINGFACE_HUB_CACHE": "/opt/ml/model/", + "mock key": "mock value", +} +EXPECTED_MODEL_DATA = { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": MOCK_MODEL_DATA_URL + "/", + } +} + + +class TestLocalTgiServing(TestCase): + def test_tgi_serving_runs_container_non_jumpstart_success(self): + # WHERE + mock_container_client = Mock() + mock_container = Mock() + mock_container_client.containers.run.return_value = mock_container + localTgiServing = LocalTgiServing() + + # WHEN + localTgiServing._start_tgi_serving( + mock_container_client, + MOCK_IMAGE, + MOCK_MODEL_PATH, + MOCK_SECRET_KEY, + MOCK_ENV_VARS, + False, + ) + + # THEN + mock_container_client.containers.run.assert_called_once_with( + MOCK_IMAGE, + shm_size=EXPECTED_SHM_SIZE, + device_requests=[ + { + "Driver": "", + "Count": -1, + "DeviceIDs": [], + "Capabilities": [["gpu"]], + "Options": {}, + } + ], + network_mode="host", + detach=True, + auto_remove=True, + volumes={ + Path(MOCK_MODEL_PATH).joinpath("code"): { + "bind": EXPECTED_MODEL_DIR_BINDING, + "mode": "rw", + } + }, + environment=EXPECTED_UPDATED_ENV_VARS, + ) + assert localTgiServing.container == mock_container + + def test_tgi_serving_runs_container_jumpstart_success(self): + # WHERE + mock_container_client = Mock() + mock_container = Mock() +
mock_container_client.containers.run.return_value = mock_container + localTgiServing = LocalTgiServing() + + # WHEN + localTgiServing._start_tgi_serving( + mock_container_client, MOCK_IMAGE, MOCK_MODEL_PATH, MOCK_SECRET_KEY, MOCK_ENV_VARS, True + ) + + # THEN + mock_container_client.containers.run.assert_called_once_with( + MOCK_IMAGE, + ["--model-id", EXPECTED_MODEL_DIR_BINDING], + shm_size=EXPECTED_SHM_SIZE, + device_requests=[ + { + "Driver": "", + "Count": -1, + "DeviceIDs": [], + "Capabilities": [["gpu"]], + "Options": {}, + } + ], + network_mode="host", + detach=True, + auto_remove=True, + volumes={ + Path(MOCK_MODEL_PATH).joinpath("code"): { + "bind": EXPECTED_MODEL_DIR_BINDING, + "mode": "rw", + } + }, + environment=MOCK_ENV_VARS, + ) + assert localTgiServing.container == mock_container + + +class TestSageMakerTgiServing(TestCase): + + @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri") + @patch("sagemaker.serve.model_server.tgi.server.parse_s3_url") + @patch("sagemaker.serve.model_server.tgi.server.fw_utils") + @patch("sagemaker.serve.model_server.tgi.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.tgi.server.s3_path_join") + @patch("sagemaker.serve.model_server.tgi.server.S3Uploader") + def test_tgi_serving_upload_tgi_artifacts_s3_url_passed_success( + self, + mock_s3_uploader, + mock_s3_path_join, + mock_determine_bucket_and_prefix, + mock_fw_utils, + mock_parse_s3_url, + mock_is_s3_uri, + ): + # WHERE + mock_is_s3_uri.return_value = False + mock_parse_s3_url.return_value = ("mock_bucket_1", "mock_prefix_1") + mock_fw_utils.model_code_key_prefix.return_value = "mock_code_key_prefix" + mock_determine_bucket_and_prefix.return_value = ("mock_bucket_2", "mock_prefix_2") + mock_s3_path_join.return_value = "mock_s3_location" + mock_s3_uploader.upload.return_value = MOCK_MODEL_DATA_URL + + sagemakerTgiServing = SageMakerTgiServing() + + # WHEN + ret_model_data, ret_env_vars = sagemakerTgiServing._upload_tgi_artifacts( + MOCK_MODEL_PATH, + MOCK_SAGEMAKER_SESSION, + False, + MOCK_S3_MODEL_DATA_URL, + MOCK_IMAGE, + MOCK_ENV_VARS, + True, + ) + + # THEN + mock_is_s3_uri.assert_called_once_with(MOCK_MODEL_PATH) + mock_parse_s3_url.assert_called_once_with(url=MOCK_S3_MODEL_DATA_URL) + mock_fw_utils.model_code_key_prefix.assert_called_once_with( + "mock_prefix_1", None, MOCK_IMAGE + ) + mock_determine_bucket_and_prefix.assert_called_once_with( + bucket="mock_bucket_1", + key_prefix="mock_code_key_prefix", + sagemaker_session=MOCK_SAGEMAKER_SESSION, + ) + mock_s3_path_join.assert_called_once_with("s3://", "mock_bucket_2", "mock_prefix_2", "code") + mock_s3_uploader.upload.assert_called_once_with( + f"{MOCK_MODEL_PATH}/code", "mock_s3_location", None, MOCK_SAGEMAKER_SESSION + ) + assert ret_model_data == EXPECTED_MODEL_DATA + assert ret_env_vars == EXPECTED_UPDATED_ENV_VARS + + @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri") + @patch("sagemaker.serve.model_server.tgi.server.parse_s3_url") + @patch("sagemaker.serve.model_server.tgi.server.fw_utils") + @patch("sagemaker.serve.model_server.tgi.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.tgi.server.s3_path_join") + @patch("sagemaker.serve.model_server.tgi.server.S3Uploader") + def test_tgi_serving_upload_tgi_artifacts_jumpstart_success( + self, + mock_s3_uploader, + mock_s3_path_join, + mock_determine_bucket_and_prefix, + mock_fw_utils, + mock_parse_s3_url, + mock_is_s3_uri, + ): + # WHERE + mock_is_s3_uri.return_value = False + mock_parse_s3_url.return_value =
("mock_bucket_1", "mock_prefix_1") + mock_fw_utils.model_code_key_prefix.return_value = "mock_code_key_prefix" + mock_determine_bucket_and_prefix.return_value = ("mock_bucket_2", "mock_prefix_2") + mock_s3_path_join.return_value = "mock_s3_location" + mock_s3_uploader.upload.return_value = MOCK_MODEL_DATA_URL + + sagemakerTgiServing = SageMakerTgiServing() + + # WHEN + ret_model_data, ret_env_vars = sagemakerTgiServing._upload_tgi_artifacts( + MOCK_MODEL_PATH, + MOCK_SAGEMAKER_SESSION, + True, + MOCK_S3_MODEL_DATA_URL, + MOCK_IMAGE, + MOCK_ENV_VARS, + True, + ) + + # THEN + mock_is_s3_uri.assert_called_once_with(MOCK_MODEL_PATH) + mock_parse_s3_url.assert_called_once_with(url=MOCK_S3_MODEL_DATA_URL) + mock_fw_utils.model_code_key_prefix.assert_called_once_with( + "mock_prefix_1", None, MOCK_IMAGE + ) + mock_determine_bucket_and_prefix.assert_called_once_with( + bucket="mock_bucket_1", + key_prefix="mock_code_key_prefix", + sagemaker_session=MOCK_SAGEMAKER_SESSION, + ) + mock_s3_path_join.assert_called_once_with("s3://", "mock_bucket_2", "mock_prefix_2", "code") + mock_s3_uploader.upload.assert_called_once_with( + f"{MOCK_MODEL_PATH}/code", "mock_s3_location", None, MOCK_SAGEMAKER_SESSION + ) + assert ret_model_data == EXPECTED_MODEL_DATA + assert ret_env_vars == {} + + @patch("sagemaker.serve.model_server.tgi.server._is_s3_uri") + @patch("sagemaker.serve.model_server.tgi.server.parse_s3_url") + @patch("sagemaker.serve.model_server.tgi.server.fw_utils") + @patch("sagemaker.serve.model_server.tgi.server.determine_bucket_and_prefix") + @patch("sagemaker.serve.model_server.tgi.server.s3_path_join") + @patch("sagemaker.serve.model_server.tgi.server.S3Uploader") + def test_tgi_serving_upload_tgi_artifacts( + self, + mock_s3_uploader, + mock_s3_path_join, + mock_determine_bucket_and_prefix, + mock_fw_utils, + mock_parse_s3_url, + mock_is_s3_uri, + ): + # WHERE + mock_is_s3_uri.return_value = True + + sagemakerTgiServing = SageMakerTgiServing() + + # WHEN + ret_model_data, ret_env_vars = sagemakerTgiServing._upload_tgi_artifacts( + MOCK_MODEL_PATH, + MOCK_SAGEMAKER_SESSION, + False, + MOCK_S3_MODEL_DATA_URL, + MOCK_IMAGE, + MOCK_ENV_VARS, + True, + ) + + # THEN + mock_is_s3_uri.assert_called_once_with(MOCK_MODEL_PATH) + assert not mock_parse_s3_url.called + assert not mock_fw_utils.model_code_key_prefix.called + assert not mock_determine_bucket_and_prefix.called + assert not mock_s3_path_join.called + assert not mock_s3_uploader.upload.called + assert ret_model_data == { + "S3DataSource": { + "CompressionType": "None", + "S3DataType": "S3Prefix", + "S3Uri": MOCK_MODEL_PATH + "/", + } + } + assert ret_env_vars == EXPECTED_UPDATED_ENV_VARS diff --git a/tests/unit/sagemaker/serve/model_server/triton/test_server.py b/tests/unit/sagemaker/serve/model_server/triton/test_server.py index c80c4296e7..3f571424ed 100644 --- a/tests/unit/sagemaker/serve/model_server/triton/test_server.py +++ b/tests/unit/sagemaker/serve/model_server/triton/test_server.py @@ -172,6 +172,7 @@ def test_upload_artifacts_sagemaker_triton_server(self, mock_upload, mock_platfo secret_key=SECRET_KEY, s3_model_data_url=S3_URI, image=GPU_TRITON_IMAGE, + should_upload_artifacts=True, ) mock_upload.assert_called_once_with(mock_session, MODEL_REPO, "mock_model_data_uri", ANY) diff --git a/tests/unit/sagemaker/serve/utils/test_hardware_detector.py b/tests/unit/sagemaker/serve/utils/test_hardware_detector.py index d383f95809..58839bfc50 100644 --- a/tests/unit/sagemaker/serve/utils/test_hardware_detector.py +++ 
b/tests/unit/sagemaker/serve/utils/test_hardware_detector.py @@ -21,7 +21,7 @@ REGION = "us-west-2" VALID_INSTANCE_TYPE = "ml.g5.48xlarge" INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge" -EXPECTED_INSTANCE_GPU_INFO = (8, 196608) +EXPECTED_INSTANCE_GPU_INFO = (8, 183104) MIB_CONVERSION_FACTOR = 0.00000095367431640625 MEMORY_BUFFER_MULTIPLIER = 1.2 # 20% buffer @@ -39,7 +39,7 @@ def test_get_gpu_info_success(sagemaker_session, boto_session): "MemoryInfo": {"SizeInMiB": 24576}, } ], - "TotalGpuMemoryInMiB": 196608, + "TotalGpuMemoryInMiB": 183104, }, } ] diff --git a/tests/unit/sagemaker/serve/utils/test_lineage_utils.py b/tests/unit/sagemaker/serve/utils/test_lineage_utils.py new file mode 100644 index 0000000000..99da766031 --- /dev/null +++ b/tests/unit/sagemaker/serve/utils/test_lineage_utils.py @@ -0,0 +1,418 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import call + +import datetime +import pytest +from botocore.exceptions import ClientError +from mock import Mock, patch +from sagemaker import Session +from sagemaker.lineage.artifact import ArtifactSummary, Artifact +from sagemaker.lineage.query import LineageSourceEnum + +from sagemaker.serve.utils.lineage_constants import ( + TRACKING_SERVER_CREATION_TIME_FORMAT, + MLFLOW_RUN_ID, + MLFLOW_MODEL_PACKAGE_PATH, + MLFLOW_S3_PATH, + MLFLOW_LOCAL_PATH, + LINEAGE_POLLER_MAX_TIMEOUT_SECS, + MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, + CONTRIBUTED_TO, + MLFLOW_REGISTRY_PATH, +) +from sagemaker.serve.utils.lineage_utils import ( + _load_artifact_by_source_uri, + _poll_lineage_artifact, + _get_mlflow_model_path_type, + _create_mlflow_model_path_lineage_artifact, + _add_association_between_artifacts, + _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact, + _maintain_lineage_tracking_for_mlflow_model, +) + + +@patch("sagemaker.lineage.artifact.Artifact.list") +def test_load_artifact_by_source_uri(mock_artifact_list): + source_uri = "s3://mybucket/mymodel" + sagemaker_session = Mock(spec=Session) + + mock_artifact_1 = Mock(spec=ArtifactSummary) + mock_artifact_1.artifact_type = LineageSourceEnum.MODEL_DATA.value + mock_artifact_2 = Mock(spec=ArtifactSummary) + mock_artifact_2.artifact_type = LineageSourceEnum.IMAGE.value + mock_artifacts = [mock_artifact_1, mock_artifact_2] + mock_artifact_list.return_value = mock_artifacts + + result = _load_artifact_by_source_uri( + source_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value + ) + + mock_artifact_list.assert_called_once_with( + source_uri=source_uri, sagemaker_session=sagemaker_session + ) + assert result == mock_artifact_1 + + +@patch("sagemaker.lineage.artifact.Artifact.list") +def test_load_artifact_by_source_uri_no_match(mock_artifact_list): + source_uri = "s3://mybucket/mymodel" + sagemaker_session = Mock(spec=Session) + + mock_artifact_1 = Mock(spec=ArtifactSummary) + mock_artifact_1.artifact_type = LineageSourceEnum.IMAGE.value + mock_artifact_2 = 
Mock(spec=ArtifactSummary) + mock_artifact_2.artifact_type = LineageSourceEnum.IMAGE.value + mock_artifacts = [mock_artifact_1, mock_artifact_2] + mock_artifact_list.return_value = mock_artifacts + + result = _load_artifact_by_source_uri( + source_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value + ) + + mock_artifact_list.assert_called_once_with( + source_uri=source_uri, sagemaker_session=sagemaker_session + ) + assert result is None + + +@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri") +def test_poll_lineage_artifact_found(mock_load_artifact): + s3_uri = "s3://mybucket/mymodel" + sagemaker_session = Mock(spec=Session) + mock_artifact = Mock(spec=ArtifactSummary) + + with patch("time.time") as mock_time: + mock_time.return_value = 0 + + mock_load_artifact.return_value = mock_artifact + + result = _poll_lineage_artifact( + s3_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session + ) + + assert result == mock_artifact + mock_load_artifact.assert_has_calls( + [ + call(s3_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value), + ] + ) + + +@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri") +def test_poll_lineage_artifact_not_found(mock_load_artifact): + s3_uri = "s3://mybucket/mymodel" + artifact_type = LineageSourceEnum.MODEL_DATA.value + sagemaker_session = Mock(spec=Session) + + with patch("time.time") as mock_time: + mock_time_values = [0.0, 1.0, LINEAGE_POLLER_MAX_TIMEOUT_SECS + 1.0] + mock_time.side_effect = mock_time_values + + with patch("time.sleep"): + mock_load_artifact.side_effect = [None, None, None] + + result = _poll_lineage_artifact(s3_uri, artifact_type, sagemaker_session) + + assert result is None + + +@pytest.mark.parametrize( + "mlflow_model_path, expected_output", + [ + ("runs:/abc123/my-model", MLFLOW_RUN_ID), + ("models:/my-model/1", MLFLOW_REGISTRY_PATH), + ( + "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model-package", + MLFLOW_MODEL_PACKAGE_PATH, + ), + ("s3://my-bucket/path/to/model", MLFLOW_S3_PATH), + ], +) +def test_get_mlflow_model_path_type_valid(mlflow_model_path, expected_output): + result = _get_mlflow_model_path_type(mlflow_model_path) + assert result == expected_output + + +@patch("os.path.exists") +def test_get_mlflow_model_path_type_valid_local_path(mock_path_exists): + valid_path = "/path/to/mlflow_model" + mock_path_exists.side_effect = lambda path: path == valid_path + result = _get_mlflow_model_path_type(valid_path) + assert result == MLFLOW_LOCAL_PATH + + +def test_get_mlflow_model_path_type_invalid(): + invalid_path = "invalid_path" + with pytest.raises(ValueError, match=f"Invalid MLflow model path: {invalid_path}"): + _get_mlflow_model_path_type(invalid_path) + + +@patch("sagemaker.serve.utils.lineage_utils._get_mlflow_model_path_type") +@patch("sagemaker.lineage.artifact.Artifact.create") +def test_create_mlflow_model_path_lineage_artifact_success( + mock_artifact_create, mock_get_mlflow_path_type +): + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_source_types = [dict(SourceIdType="Custom", Value="ModelBuilderInputModelData")] + sagemaker_session = Mock(spec=Session) + mock_artifact = Mock(spec=Artifact) + mock_get_mlflow_path_type.return_value = "mlflow_run_id" + mock_artifact_create.return_value = mock_artifact + + result = _create_mlflow_model_path_lineage_artifact(mlflow_model_path, sagemaker_session) + + assert result == mock_artifact + mock_get_mlflow_path_type.assert_called_once_with(mlflow_model_path) + 
mock_artifact_create.assert_called_once_with( + source_uri=mlflow_model_path, + source_types=mock_source_types, + artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE, + artifact_name="mlflow_run_id", + properties={"model_builder_input_model_data_type": "mlflow_run_id"}, + sagemaker_session=sagemaker_session, + ) + + +@patch("sagemaker.serve.utils.lineage_utils._get_mlflow_model_path_type") +@patch("sagemaker.lineage.artifact.Artifact.create") +def test_create_mlflow_model_path_lineage_artifact_validation_exception( + mock_artifact_create, mock_get_mlflow_path_type +): + mlflow_model_path = "runs:/Ab12Cd34/my-model" + sagemaker_session = Mock(spec=Session) + mock_get_mlflow_path_type.return_value = "mlflow_run_id" + mock_artifact_create.side_effect = ClientError( + error_response={"Error": {"Code": "ValidationException"}}, operation_name="CreateArtifact" + ) + + result = _create_mlflow_model_path_lineage_artifact(mlflow_model_path, sagemaker_session) + + assert result is None + + +@patch("sagemaker.serve.utils.lineage_utils._get_mlflow_model_path_type") +@patch("sagemaker.lineage.artifact.Artifact.create") +def test_create_mlflow_model_path_lineage_artifact_other_exception( + mock_artifact_create, mock_get_mlflow_path_type +): + mlflow_model_path = "runs:/Ab12Cd34/my-model" + sagemaker_session = Mock(spec=Session) + mock_get_mlflow_path_type.return_value = "mlflow_run_id" + mock_artifact_create.side_effect = ClientError( + error_response={"Error": {"Code": "SomeOtherException"}}, operation_name="CreateArtifact" + ) + + with pytest.raises(ClientError): + _create_mlflow_model_path_lineage_artifact(mlflow_model_path, sagemaker_session) + + +@patch("sagemaker.serve.utils.lineage_utils._create_mlflow_model_path_lineage_artifact") +@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri") +def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_existing( + mock_load_artifact, mock_create_artifact +): + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) + mock_creation_time = datetime.datetime(2024, 5, 15, 0, 0, 0) + sagemaker_session = Mock(spec=Session) + mock_sagemaker_client = Mock() + mock_describe_response = {"CreationTime": mock_creation_time} + mock_sagemaker_client.describe_mlflow_tracking_server.return_value = mock_describe_response + sagemaker_session.sagemaker_client = mock_sagemaker_client + mock_source_types_to_match = [ + "ModelBuilderInputModelData", + mock_tracking_server_arn, + mock_creation_time.strftime(TRACKING_SERVER_CREATION_TIME_FORMAT), + ] + mock_artifact_summary = Mock(spec=ArtifactSummary) + mock_load_artifact.return_value = mock_artifact_summary + + result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( + mlflow_model_path, sagemaker_session, mock_tracking_server_arn + ) + + assert result == mock_artifact_summary + mock_load_artifact.assert_called_once_with( + mlflow_model_path, + sagemaker_session, + mock_source_types_to_match, + ) + mock_create_artifact.assert_not_called() + + +@patch("sagemaker.serve.utils.lineage_utils._create_mlflow_model_path_lineage_artifact") +@patch("sagemaker.serve.utils.lineage_utils._load_artifact_by_source_uri") +def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_create( + mock_load_artifact, mock_create_artifact +): + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + 
"arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) + mock_creation_time = datetime.datetime(2024, 5, 15, 0, 0, 0) + sagemaker_session = Mock(spec=Session) + mock_sagemaker_client = Mock() + mock_describe_response = {"CreationTime": mock_creation_time} + mock_sagemaker_client.describe_mlflow_tracking_server.return_value = mock_describe_response + sagemaker_session.sagemaker_client = mock_sagemaker_client + mock_source_types_to_match = [ + "ModelBuilderInputModelData", + mock_tracking_server_arn, + mock_creation_time.strftime(TRACKING_SERVER_CREATION_TIME_FORMAT), + ] + mock_artifact = Mock(spec=Artifact) + mock_load_artifact.return_value = None + mock_create_artifact.return_value = mock_artifact + + result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact( + mlflow_model_path, sagemaker_session, mock_tracking_server_arn + ) + + assert result == mock_artifact + mock_load_artifact.assert_called_once_with( + mlflow_model_path, + sagemaker_session, + mock_source_types_to_match, + ) + mock_create_artifact.assert_called_once_with( + mlflow_model_path, sagemaker_session, mock_source_types_to_match + ) + + +@patch("sagemaker.lineage.association.Association.create") +def test_add_association_between_artifacts_success(mock_association_create): + mlflow_model_path_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/123" + autogenerated_model_data_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/456" + sagemaker_session = Mock(spec=Session) + + _add_association_between_artifacts( + mlflow_model_path_artifact_arn, + autogenerated_model_data_artifact_arn, + sagemaker_session, + ) + + mock_association_create.assert_called_once_with( + source_arn=mlflow_model_path_artifact_arn, + destination_arn=autogenerated_model_data_artifact_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + +@patch("sagemaker.lineage.association.Association.create") +def test_add_association_between_artifacts_validation_exception(mock_association_create): + mlflow_model_path_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/123" + autogenerated_model_data_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/456" + sagemaker_session = Mock(spec=Session) + mock_association_create.side_effect = ClientError( + error_response={"Error": {"Code": "ValidationException"}}, + operation_name="CreateAssociation", + ) + + _add_association_between_artifacts( + mlflow_model_path_artifact_arn, + autogenerated_model_data_artifact_arn, + sagemaker_session, + ) + + +@patch("sagemaker.lineage.association.Association.create") +def test_add_association_between_artifacts_other_exception(mock_association_create): + mlflow_model_path_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/123" + autogenerated_model_data_artifact_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/456" + sagemaker_session = Mock(spec=Session) + mock_association_create.side_effect = ClientError( + error_response={"Error": {"Code": "SomeOtherException"}}, operation_name="CreateAssociation" + ) + + with pytest.raises(ClientError): + _add_association_between_artifacts( + mlflow_model_path_artifact_arn, + autogenerated_model_data_artifact_arn, + sagemaker_session, + ) + + +@patch("sagemaker.serve.utils.lineage_utils._poll_lineage_artifact") +@patch( + "sagemaker.serve.utils.lineage_utils._retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact" +) 
+@patch("sagemaker.serve.utils.lineage_utils._add_association_between_artifacts") +def test_maintain_lineage_tracking_for_mlflow_model_success( + mock_add_association, mock_retrieve_create_artifact, mock_poll_artifact +): + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) + s3_upload_path = "s3://mybucket/path/to/model" + sagemaker_session = Mock(spec=Session) + mock_model_data_artifact = Mock(spec=ArtifactSummary) + mock_mlflow_model_artifact = Mock(spec=Artifact) + mock_poll_artifact.return_value = mock_model_data_artifact + mock_retrieve_create_artifact.return_value = mock_mlflow_model_artifact + + _maintain_lineage_tracking_for_mlflow_model( + mlflow_model_path, s3_upload_path, sagemaker_session, mock_tracking_server_arn + ) + + mock_poll_artifact.assert_called_once_with( + s3_uri=s3_upload_path, + artifact_type=LineageSourceEnum.MODEL_DATA.value, + sagemaker_session=sagemaker_session, + ) + mock_retrieve_create_artifact.assert_called_once_with( + mlflow_model_path=mlflow_model_path, + tracking_server_arn=mock_tracking_server_arn, + sagemaker_session=sagemaker_session, + ) + mock_add_association.assert_called_once_with( + mlflow_model_path_artifact_arn=mock_mlflow_model_artifact.artifact_arn, + autogenerated_model_data_artifact_arn=mock_model_data_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ) + + +@patch("sagemaker.serve.utils.lineage_utils._poll_lineage_artifact") +@patch( + "sagemaker.serve.utils.lineage_utils._retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact" +) +@patch("sagemaker.serve.utils.lineage_utils._add_association_between_artifacts") +def test_maintain_lineage_tracking_for_mlflow_model_no_model_data_artifact( + mock_add_association, mock_retrieve_create_artifact, mock_poll_artifact +): + mlflow_model_path = "runs:/Ab12Cd34/my-model" + mock_tracking_server_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/test" + ) + s3_upload_path = "s3://mybucket/path/to/model" + sagemaker_session = Mock(spec=Session) + mock_poll_artifact.return_value = None + mock_retrieve_create_artifact.return_value = None + + _maintain_lineage_tracking_for_mlflow_model( + mlflow_model_path, s3_upload_path, sagemaker_session, mock_tracking_server_arn + ) + + mock_poll_artifact.assert_called_once_with( + s3_uri=s3_upload_path, + artifact_type=LineageSourceEnum.MODEL_DATA.value, + sagemaker_session=sagemaker_session, + ) + mock_retrieve_create_artifact.assert_not_called() + mock_add_association.assert_not_called() diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py new file mode 100644 index 0000000000..b392b255da --- /dev/null +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -0,0 +1,591 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import unittest +from unittest.mock import Mock, patch + +import pytest + +from sagemaker.enums import Tag +from sagemaker.serve.utils.optimize_utils import ( + _generate_optimized_model, + _update_environment_variables, + _is_image_compatible_with_optimization_job, + _extract_speculative_draft_model_provider, + _extracts_and_validates_speculative_model_source, + _is_s3_uri, + _generate_additional_model_data_sources, + _generate_channel_name, + _extract_optimization_config_and_env, + _is_optimized, + _custom_speculative_decoding, + _is_inferentia_or_trainium, + _is_draft_model_gated, + _deployment_config_contains_draft_model, + _jumpstart_speculative_decoding, +) +from tests.unit.sagemaker.serve.constants import ( + GATED_DRAFT_MODEL_CONFIG, + NON_GATED_DRAFT_MODEL_CONFIG, + OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL, + NON_OPTIMIZED_DEPLOYMENT_CONFIG, +) + +mock_optimization_job_output = { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:312206380606:optimization-job/" + "modelbuilderjob-3cbf9c40b63c455d85b60033f9a01691", + "OptimizationJobStatus": "COMPLETED", + "OptimizationJobName": "modelbuilderjob-3cbf9c40b63c455d85b60033f9a01691", + "ModelSource": { + "S3": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/" + } + }, + "OptimizationEnvironment": { + "ENDPOINT_SERVER_TIMEOUT": "3600", + "HF_MODEL_ID": "/opt/ml/model", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "SAGEMAKER_PROGRAM": "inference.py", + }, + "DeploymentInstanceType": "ml.g5.2xlarge", + "OptimizationConfigs": [ + { + "ModelQuantizationConfig": { + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124", + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + } + } + ], + "OutputConfig": {"S3OutputLocation": "s3://quicksilver-model-data/llama-3-8b/quantized-1/"}, + "OptimizationOutput": { + "RecommendedInferenceImage": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124" + }, + "RoleArn": "arn:aws:iam::312206380606:role/service-role/AmazonSageMaker-ExecutionRole-20240116T151132", + "StoppingCondition": {"MaxRuntimeInSeconds": 36000}, + "ResponseMetadata": { + "RequestId": "a95253d5-c045-4708-8aac-9f0d327515f7", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "a95253d5-c045-4708-8aac-9f0d327515f7", + "content-type": "application/x-amz-json-1.1", + "content-length": "1371", + "date": "Fri, 21 Jun 2024 04:27:42 GMT", + }, + "RetryAttempts": 0, + }, +} + + +@pytest.mark.parametrize( + "instance, expected", + [ + ("ml.trn1.2xlarge", True), + ("ml.inf2.xlarge", True), + ("ml.c7gd.4xlarge", False), + ], +) +def test_is_inferentia_or_trainium(instance, expected): + assert _is_inferentia_or_trainium(instance) == expected + + +@pytest.mark.parametrize( + "image_uri, expected", + [ + ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124", + True, + ), + ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-neuronx-sdk2.18.2", + True, + ), + ( + None, + True, + ), + ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:" + "2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04", + False, + ), + (None, True), + ], +) +def test_is_image_compatible_with_optimization_job(image_uri, expected): + assert _is_image_compatible_with_optimization_job(image_uri) == expected + + +def 
test_generate_optimized_model(): + pysdk_model = Mock() + pysdk_model.model_data = { + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/" + } + } + + optimized_model = _generate_optimized_model(pysdk_model, mock_optimization_job_output) + + assert ( + optimized_model.image_uri + == mock_optimization_job_output["OptimizationOutput"]["RecommendedInferenceImage"] + ) + assert ( + optimized_model.model_data["S3DataSource"]["S3Uri"] + == mock_optimization_job_output["OutputConfig"]["S3OutputLocation"] + ) + assert optimized_model.instance_type == mock_optimization_job_output["DeploymentInstanceType"] + pysdk_model.add_tags.assert_called_once_with( + { + "Key": Tag.OPTIMIZATION_JOB_NAME, + "Value": mock_optimization_job_output["OptimizationJobName"], + } + ) + + +def test_is_optimized(): + model = Mock() + + model._tags = {"Key": Tag.OPTIMIZATION_JOB_NAME} + assert _is_optimized(model) is True + + model._tags = [{"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER}] + assert _is_optimized(model) is True + + model._tags = [{"Key": Tag.FINE_TUNING_MODEL_PATH}] + assert _is_optimized(model) is False + + +@pytest.mark.parametrize( + "env, new_env, output_env", + [ + ({"a": "1"}, {"b": "2"}, {"a": "1", "b": "2"}), + (None, {"b": "2"}, {"b": "2"}), + ({"a": "1"}, None, {"a": "1"}), + (None, None, None), + ], +) +def test_update_environment_variables(env, new_env, output_env): + assert _update_environment_variables(env, new_env) == output_env + + +@pytest.mark.parametrize( + "speculative_decoding_config, expected_model_provider", + [ + ({"ModelProvider": "SageMaker"}, "sagemaker"), + ({"ModelProvider": "Custom"}, "custom"), + ({"ModelSource": "s3://"}, "custom"), + ({"ModelProvider": "JumpStart"}, "jumpstart"), + ({"ModelProvider": "asdf"}, "auto"), + ({"ModelProvider": "Auto"}, "auto"), + (None, None), + ], +) +def test_extract_speculative_draft_model_provider( + speculative_decoding_config, expected_model_provider +): + assert ( + _extract_speculative_draft_model_provider(speculative_decoding_config) + == expected_model_provider + ) + + +def test_extract_speculative_draft_model_s3_uri(): + res = _extracts_and_validates_speculative_model_source({"ModelSource": "s3://"}) + assert res == "s3://" + + +def test_extract_speculative_draft_model_s3_uri_ex(): + with pytest.raises(ValueError): + _extracts_and_validates_speculative_model_source({"ModelSource": None}) + + +def test_generate_channel_name(): + assert _generate_channel_name(None) is not None + + additional_model_data_sources = _generate_additional_model_data_sources( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", True + ) + + assert _generate_channel_name(additional_model_data_sources) == "channel_name" + + +def test_generate_additional_model_data_sources(): + model_source = _generate_additional_model_data_sources( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", True + ) + + assert model_source == [ + { + "ChannelName": "channel_name", + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"AcceptEula": True}, + }, + } + ] + + model_source = _generate_additional_model_data_sources( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", "channel_name", False + ) + + assert model_source == [ + { + "ChannelName": 
"channel_name", + "S3DataSource": { + "S3Uri": "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + }, + } + ] + + +@pytest.mark.parametrize( + "s3_uri, expected", + [ + ( + "s3://jumpstart-private-cache-alpha-us-west-2/meta-textgeneration/" + "meta-textgeneration-llama-3-8b/artifacts/inference-prepack/v1.0.1/", + True, + ), + ("invalid://", False), + ], +) +def test_is_s3_uri(s3_uri, expected): + assert _is_s3_uri(s3_uri) == expected + + +@pytest.mark.parametrize( + "draft_model_config, expected", + [ + (GATED_DRAFT_MODEL_CONFIG, True), + (NON_GATED_DRAFT_MODEL_CONFIG, False), + ], +) +def test_is_draft_model_gated(draft_model_config, expected): + assert _is_draft_model_gated(draft_model_config) is expected + + +@pytest.mark.parametrize( + ( + "quantization_config, compilation_config, sharding_config, expected_config, " + "expected_quant_env, expected_compilation_env, expected_sharding_env" + ), + [ + ( + None, + { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + None, + { + "ModelCompilationConfig": { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + }, + None, + { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, + None, + ), + ( + { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + None, + None, + { + "ModelQuantizationConfig": { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + }, + { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, + None, + None, + ), + ( + None, + None, + { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + { + "ModelShardingConfig": { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + }, + None, + None, + { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, + ), + (None, None, None, None, None, None, None), + ], +) +def test_extract_optimization_config_and_env( + quantization_config, + compilation_config, + sharding_config, + expected_config, + expected_quant_env, + expected_compilation_env, + expected_sharding_env, +): + assert _extract_optimization_config_and_env( + quantization_config, compilation_config, sharding_config + ) == ( + expected_config, + expected_quant_env, + expected_compilation_env, + expected_sharding_env, + ) + + +@pytest.mark.parametrize( + "deployment_config", + [ + (OPTIMIZED_DEPLOYMENT_CONFIG_WITH_GATED_DRAFT_MODEL, True), + (NON_OPTIMIZED_DEPLOYMENT_CONFIG, False), + (None, False), + ], +) +def deployment_config_contains_draft_model(deployment_config, expected): + assert _deployment_config_contains_draft_model(deployment_config) + + +class TestJumpStartSpeculativeDecodingConfig(unittest.TestCase): + + @patch("sagemaker.model.Model") + def test_with_no_js_model_id(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = {"ModelSource": "JumpStart"} + + with self.assertRaises(ValueError) as _: + _jumpstart_speculative_decoding(mock_model, speculative_decoding_config) + + @patch( + "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket", + return_value="js_gated_content_bucket", + ) + @patch( + "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket", + return_value="js_content_bucket", + ) + @patch( + "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs", + return_value=Mock(), + ) + @patch("sagemaker.model.Model") + def test_with_gated_js_model( + self, + 
mock_model, + mock_model_specs, + mock_js_content_bucket, + mock_js_gated_content_bucket, + ): + mock_sagemaker_session = Mock() + mock_sagemaker_session.boto_region_name = "us-west-2" + + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "JumpStart", + "ModelID": "meta-textgeneration-llama-3-2-1b", + "AcceptEula": True, + } + + mock_model_specs.return_value.to_json.return_value = { + "gated_bucket": True, + "hosting_prepacked_artifact_key": "hosting_prepacked_artifact_key", + } + + _jumpstart_speculative_decoding( + mock_model, speculative_decoding_config, mock_sagemaker_session + ) + + expected_env_var = { + "OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model/" + } + self.maxDiff = None + + self.assertEqual( + mock_model.additional_model_data_sources, + [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": f"s3://{mock_js_gated_content_bucket.return_value}/hosting_prepacked_artifact_key", + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"AcceptEula": True}, + }, + } + ], + ) + + mock_model.add_tags.assert_called_once_with( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "jumpstart"} + ) + self.assertEqual(mock_model.env, expected_env_var) + + @patch( + "sagemaker.serve.utils.optimize_utils.get_eula_message", return_value="Accept eula message" + ) + @patch( + "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket", + return_value="js_gated_content_bucket", + ) + @patch( + "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_jumpstart_content_bucket", + return_value="js_content_bucket", + ) + @patch( + "sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs", + return_value=Mock(), + ) + @patch("sagemaker.model.Model") + def test_with_gated_js_model_and_accept_eula_false( + self, + mock_model, + mock_model_specs, + mock_js_content_bucket, + mock_js_gated_content_bucket, + mock_eula_message, + ): + mock_sagemaker_session = Mock() + mock_sagemaker_session.boto_region_name = "us-west-2" + + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "JumpStart", + "ModelID": "meta-textgeneration-llama-3-2-1b", + "AcceptEula": False, + } + + mock_model_specs.return_value.to_json.return_value = { + "gated_bucket": True, + "hosting_prepacked_artifact_key": "hosting_prepacked_artifact_key", + } + + self.assertRaisesRegex( + ValueError, + f"{mock_eula_message.return_value} Set `AcceptEula`=True in " + f"speculative_decoding_config once acknowledged.", + _jumpstart_speculative_decoding, + mock_model, + speculative_decoding_config, + mock_sagemaker_session, + ) + + +class TestCustomSpeculativeDecodingConfig(unittest.TestCase): + + @patch("sagemaker.model.Model") + def test_with_s3_hf(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "s3://bucket/djl-inference-2024-07-02-00-03-32-127/code" + } + + res_model = _custom_speculative_decoding(mock_model, speculative_decoding_config) + + mock_model.add_tags.assert_called_once_with( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"} + ) + + self.assertEqual( + res_model.env, + {"OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model"}, + ) + self.assertEqual( + res_model.additional_model_data_sources, + [ + { + "ChannelName": 
"draft_model", + "S3DataSource": { + "S3Uri": "s3://bucket/djl-inference-2024-07-02-00-03-32-127/code", + "S3DataType": "S3Prefix", + "CompressionType": "None", + }, + } + ], + ) + + @patch("sagemaker.model.Model") + def test_with_s3_js(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = { + "ModelSource": "s3://bucket/huggingface-pytorch-tgi-inference" + } + + res_model = _custom_speculative_decoding(mock_model, speculative_decoding_config, True) + + self.assertEqual( + res_model.additional_model_data_sources, + [ + { + "ChannelName": "draft_model", + "S3DataSource": { + "S3Uri": "s3://bucket/huggingface-pytorch-tgi-inference", + "S3DataType": "S3Prefix", + "CompressionType": "None", + "ModelAccessConfig": {"AcceptEula": True}, + }, + } + ], + ) + + @patch("sagemaker.model.Model") + def test_with_non_s3(self, mock_model): + mock_model.env = {} + mock_model.additional_model_data_sources = None + speculative_decoding_config = {"ModelSource": "huggingface-pytorch-tgi-inference"} + + res_model = _custom_speculative_decoding(mock_model, speculative_decoding_config, False) + + self.assertIsNone(res_model.additional_model_data_sources) + self.assertEqual( + res_model.env, + {"OPTION_SPECULATIVE_DRAFT_MODEL": "huggingface-pytorch-tgi-inference"}, + ) + + mock_model.add_tags.assert_called_once_with( + {"Key": Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER, "Value": "custom"} + ) diff --git a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py index e8273dd9a1..fc832ad02d 100644 --- a/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py +++ b/tests/unit/sagemaker/serve/utils/test_telemetry_logger.py @@ -12,8 +12,9 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock from sagemaker.serve import Mode, ModelServer +from sagemaker.serve.model_format.mlflow.constants import MLFLOW_MODEL_PATH, MLFLOW_TRACKING_ARN from sagemaker.serve.utils.telemetry_logger import ( _send_telemetry, _capture_telemetry, @@ -24,7 +25,8 @@ from sagemaker.user_agent import SDK_VERSION MOCK_SESSION = Mock() -MOCK_FUNC_NAME = "Mock.deploy" +MOCK_DEPLOY_FUNC_NAME = "Mock.deploy" +MOCK_OPTIMIZE_FUNC_NAME = "Mock.optimize" MOCK_DJL_CONTAINER = ( "763104351884.dkr.ecr.us-west-2.amazonaws.com/" "djl-inference:0.25.0-deepspeed0.11.0-cu118" ) @@ -32,9 +34,16 @@ "763104351884.dkr.ecr.us-east-1.amazonaws.com/" "huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04" ) +MOCK_PYTORCH_CONTAINER = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.0.1-cpu-py310" +) MOCK_HUGGINGFACE_ID = "meta-llama/Llama-2-7b-hf" MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex") MOCK_ENDPOINT_ARN = "arn:aws:sagemaker:us-west-2:123456789012:endpoint/test" +MOCK_MODEL_METADATA_FOR_MLFLOW = { + MLFLOW_MODEL_PATH: "s3://some_path", + MLFLOW_TRACKING_ARN: "arn:aws:sagemaker:us-west-2:000000000000:mlflow-tracking-server/test", +} class ModelBuilderMock: @@ -42,11 +51,15 @@ def __init__(self): self.serve_settings = Mock() self.sagemaker_session = MOCK_SESSION - @_capture_telemetry(MOCK_FUNC_NAME) + @_capture_telemetry(MOCK_DEPLOY_FUNC_NAME) def mock_deploy(self, mock_exception_func=None): if mock_exception_func: mock_exception_func() + @_capture_telemetry(MOCK_OPTIMIZE_FUNC_NAME) + def mock_optimize(self, *args, **kwargs): + pass + class TestTelemetryLogger(unittest.TestCase): @patch("sagemaker.serve.utils.telemetry_logger._requests_helper") @@ -83,7 +96,7 @@ def test_capture_telemetry_decorator_djl_success(self, mock_send_telemetry): args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=4" "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118" f"&x-sdkVersion={SDK_VERSION}" @@ -113,7 +126,7 @@ def test_capture_telemetry_decorator_djl_success_with_custom_image(self, mock_se args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=4" "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118" f"&x-sdkVersion={SDK_VERSION}" @@ -143,7 +156,7 @@ def test_capture_telemetry_decorator_tgi_success(self, mock_send_telemetry): args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=6" "&x-imageTag=huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04" f"&x-sdkVersion={SDK_VERSION}" @@ -191,7 +204,7 @@ def test_capture_telemetry_decorator_handle_exception_success(self, mock_send_te args = mock_send_telemetry.call_args.args latency = str(args[5]).split("latency=")[1] expected_extra_str = ( - f"{MOCK_FUNC_NAME}" + f"{MOCK_DEPLOY_FUNC_NAME}" "&x-modelServer=4" "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118" f"&x-sdkVersion={SDK_VERSION}" @@ -238,4 +251,96 @@ def test_construct_url_with_failure_reason_and_extra_info(self): f"&x-failureType={mock_failure_type}" f"&x-extra={mock_extra_info}" ) - self.assertEquals(ret_url, expected_base_url) + 
self.assertEqual(ret_url, expected_base_url) + + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_capture_telemetry_decorator_mlflow_success(self, mock_send_telemetry): + mock_model_builder = ModelBuilderMock() + mock_model_builder.serve_settings.telemetry_opt_out = False + mock_model_builder.image_uri = MOCK_PYTORCH_CONTAINER + mock_model_builder._is_mlflow_model = True + mock_model_builder.model_metadata = MOCK_MODEL_METADATA_FOR_MLFLOW + mock_model_builder._is_custom_image_uri = False + mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT + mock_model_builder.model_server = ModelServer.TORCHSERVE + mock_model_builder.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN + + mock_model_builder.mock_deploy() + + args = mock_send_telemetry.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_DEPLOY_FUNC_NAME}" + "&x-modelServer=1" + "&x-imageTag=pytorch-inference:2.0.1-cpu-py310" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-defaultImageUsage={ImageUriOption.DEFAULT_IMAGE.value}" + f"&x-endpointArn={MOCK_ENDPOINT_ARN}" + f"&x-mlflowModelPathType=2" + f"&x-mlflowTrackingServerArn={MOCK_MODEL_METADATA_FOR_MLFLOW[MLFLOW_TRACKING_ARN]}" + f"&x-latency={latency}" + ) + + mock_send_telemetry.assert_called_once_with( + "1", 3, MOCK_SESSION, None, None, expected_extra_str + ) + + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_capture_telemetry_decorator_optimize_with_default_configs(self, mock_send_telemetry): + mock_model_builder = ModelBuilderMock() + mock_model_builder.serve_settings.telemetry_opt_out = False + mock_model_builder.image_uri = None + mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT + mock_model_builder.model_server = ModelServer.TORCHSERVE + mock_model_builder.sagemaker_session.endpoint_arn = None + + mock_model_builder.mock_optimize() + + args = mock_send_telemetry.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_OPTIMIZE_FUNC_NAME}" + "&x-modelServer=1" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-latency={latency}" + ) + + mock_send_telemetry.assert_called_once_with( + "1", 3, MOCK_SESSION, None, None, expected_extra_str + ) + + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_capture_telemetry_decorator_optimize_with_custom_configs(self, mock_send_telemetry): + mock_model_builder = ModelBuilderMock() + mock_model_builder.serve_settings.telemetry_opt_out = False + mock_model_builder.image_uri = None + mock_model_builder.mode = Mode.SAGEMAKER_ENDPOINT + mock_model_builder.model_server = ModelServer.TORCHSERVE + mock_model_builder.sagemaker_session.endpoint_arn = None + mock_model_builder.is_fine_tuned = True + mock_model_builder.is_compiled = True + mock_model_builder.is_quantized = True + mock_model_builder.speculative_decoding_draft_model_source = "sagemaker" + + mock_speculative_decoding_config = MagicMock() + mock_config = {"ModelProvider": "sagemaker"} + mock_speculative_decoding_config.__getitem__.side_effect = mock_config.__getitem__ + + mock_model_builder.mock_optimize() + + args = mock_send_telemetry.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_OPTIMIZE_FUNC_NAME}" + "&x-modelServer=1" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-fineTuned=1" + f"&x-compiled=1" + f"&x-quantized=1" + f"&x-sdDraftModelSource=1" + f"&x-latency={latency}" + ) + + mock_send_telemetry.assert_called_once_with( + "1", 3, MOCK_SESSION, None, None, expected_extra_str + ) diff --git 
a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py new file mode 100644 index 0000000000..bd8db82a16 --- /dev/null +++ b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py @@ -0,0 +1,338 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import unittest +import pytest +import requests +from unittest.mock import Mock, patch, MagicMock +import boto3 +import sagemaker +from sagemaker.telemetry.constants import Feature +from sagemaker.telemetry.telemetry_logging import ( + _send_telemetry_request, + _telemetry_emitter, + _construct_url, + _get_accountId, + _requests_helper, + _get_region_or_default, + _get_default_sagemaker_session, + OS_NAME_VERSION, + PYTHON_VERSION, +) +from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file +from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException + +MOCK_SESSION = Mock() +MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex") +MOCK_FEATURE = Feature.SDK_DEFAULTS +MOCK_FUNC_NAME = "Mock.local_session.create_model" +MOCK_ENDPOINT_ARN = "arn:aws:sagemaker:us-west-2:123456789012:endpoint/test" + + +class LocalSagemakerClientMock: + def __init__(self): + self.sagemaker_session = MOCK_SESSION + + @_telemetry_emitter(MOCK_FEATURE, MOCK_FUNC_NAME) + def mock_create_model(self, mock_exception_func=None): + if mock_exception_func: + mock_exception_func() + + +class TestTelemetryLogging(unittest.TestCase): + @patch("sagemaker.telemetry.telemetry_logging._requests_helper") + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + def test_log_successfully(self, mock_get_accountId, mock_request_helper): + """Test to check if the telemetry logging is successful""" + MOCK_SESSION.boto_session.region_name = "us-west-2" + mock_get_accountId.return_value = "testAccountId" + _send_telemetry_request("someStatus", "1", MOCK_SESSION) + mock_request_helper.assert_called_with( + "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/" + "telemetry?x-accountId=testAccountId&x-status=someStatus&x-feature=1", + 2, + ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + def test_log_handle_exception(self, mock_get_accountId): + """Test that telemetry logging swallows internal errors instead of raising""" + mock_get_accountId.side_effect = Exception("Internal error") + # Should not raise even though _get_accountId fails + _send_telemetry_request("someStatus", "1", MOCK_SESSION) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_success(self, mock_get_region, mock_get_accountId): + """Test to check the _send_telemetry_request function with success status""" + mock_get_accountId.return_value = "testAccountId" + mock_get_region.return_value = "us-west-2" + + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + 
mock_requests_helper.return_value = None + _send_telemetry_request(1, [1, 2], MagicMock(), None, None, "extra_info") + mock_requests_helper.assert_called_with( + "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/" + "telemetry?x-accountId=testAccountId&x-status=1&x-feature=1,2&x-extra=extra_info", + 2, + ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_failure(self, mock_get_region, mock_get_accountId): + """Test to check the _send_telemetry_request function with failure status""" + mock_get_accountId.return_value = "testAccountId" + mock_get_region.return_value = "us-west-2" + + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + mock_requests_helper.return_value = None + _send_telemetry_request( + 0, [1, 2], MagicMock(), "failure_reason", "failure_type", "extra_info" + ) + mock_requests_helper.assert_called_with( + "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/" + "telemetry?x-accountId=testAccountId&x-status=0&x-feature=1,2" + "&x-failureReason=failure_reason&x-failureType=failure_type&x-extra=extra_info", + 2, + ) + + @patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_decorator_no_call_when_disabled( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test that _send_telemetry_request is not called when telemetry is opted out""" + mock_resolve_config.return_value = True + mock_local_client = LocalSagemakerClientMock() + + mock_local_client.mock_create_model() + + assert not mock_send_telemetry_request.called + + @patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_decorator_success( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test to verify the _telemetry_emitter decorator with success status""" + mock_resolve_config.return_value = False + mock_local_client = LocalSagemakerClientMock() + mock_local_client.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN + mock_local_client.mock_create_model() + app_type = process_studio_metadata_file() + + args = mock_send_telemetry_request.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_FUNC_NAME}" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-env={PYTHON_VERSION}" + f"&x-sys={OS_NAME_VERSION}" + f"&x-platform={app_type}" + f"&x-endpointArn={MOCK_ENDPOINT_ARN}" + f"&x-latency={latency}" + ) + + mock_send_telemetry_request.assert_called_once_with( + 1, [1, 2], MOCK_SESSION, None, None, expected_extra_str + ) + + @patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_decorator_handle_exception_success( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test to verify the _telemetry_emitter decorator when function emits exception""" + mock_resolve_config.return_value = False + mock_local_client = LocalSagemakerClientMock() + mock_local_client.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN + app_type = process_studio_metadata_file() + + mock_exception = Mock() + mock_exception_obj = MOCK_EXCEPTION + mock_exception.side_effect = mock_exception_obj + + with self.assertRaises(ModelBuilderException) as _: + mock_local_client.mock_create_model(mock_exception) + + args = 
mock_send_telemetry_request.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_FUNC_NAME}" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-env={PYTHON_VERSION}" + f"&x-sys={OS_NAME_VERSION}" + f"&x-platform={app_type}" + f"&x-endpointArn={MOCK_ENDPOINT_ARN}" + f"&x-latency={latency}" + ) + + mock_send_telemetry_request.assert_called_once_with( + 0, + [1, 2], + MOCK_SESSION, + str(mock_exception_obj), + mock_exception_obj.__class__.__name__, + expected_extra_str, + ) + + def test_construct_url_with_failure_reason_and_extra_info(self): + """Test to verify the _construct_url function with failure reason and extra info""" + mock_accountId = "testAccountId" + mock_status = 0 + mock_feature = "1,2" + mock_failure_reason = str(MOCK_EXCEPTION) + mock_failure_type = MOCK_EXCEPTION.__class__.__name__ + mock_extra_info = "mock_extra_info" + mock_region = "us-west-2" + + resulted_url = _construct_url( + accountId=mock_accountId, + region=mock_region, + status=mock_status, + feature=mock_feature, + failure_reason=mock_failure_reason, + failure_type=mock_failure_type, + extra_info=mock_extra_info, + ) + + expected_base_url = ( + f"https://sm-pysdk-t-{mock_region}.s3.{mock_region}.amazonaws.com/telemetry?" + f"x-accountId={mock_accountId}" + f"&x-status={mock_status}" + f"&x-feature={mock_feature}" + f"&x-failureReason={mock_failure_reason}" + f"&x-failureType={mock_failure_type}" + f"&x-extra={mock_extra_info}" + ) + self.assertEqual(resulted_url, expected_base_url) + + @patch("sagemaker.telemetry.telemetry_logging.requests.get") + def test_requests_helper_success(self, mock_requests_get): + """Test to verify the _requests_helper function with success status""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_requests_get.return_value = mock_response + url = "https://example.com" + timeout = 10 + + response = _requests_helper(url, timeout) + + mock_requests_get.assert_called_once_with(url, timeout) + self.assertEqual(response, mock_response) + + @patch("sagemaker.telemetry.telemetry_logging.requests.get") + def test_requests_helper_exception(self, mock_requests_get): + """Test to verify the _requests_helper function with exception""" + mock_requests_get.side_effect = requests.exceptions.RequestException("Error making request") + url = "https://example.com" + timeout = 10 + + response = _requests_helper(url, timeout) + + mock_requests_get.assert_called_once_with(url, timeout) + self.assertIsNone(response) + + def test_get_accountId_success(self): + """Test to verify the _get_accountId function with success status""" + boto_mock = MagicMock(name="boto_session") + boto_mock.client("sts").get_caller_identity.return_value = {"Account": "testAccountId"} + session = sagemaker.Session(boto_session=boto_mock) + account_id = _get_accountId(session) + + self.assertEqual(account_id, "testAccountId") + + def test_get_accountId_exception(self): + """Test to verify the _get_accountId function with exception""" + sts_client_mock = MagicMock() + sts_client_mock.side_effect = Exception("Error creating STS client") + boto_mock = MagicMock(name="boto_session") + boto_mock.client("sts").get_caller_identity.return_value = sts_client_mock + session = sagemaker.Session(boto_session=boto_mock) + + with pytest.raises(Exception) as exception: + account_id = _get_accountId(session) + assert account_id is None + assert "Error creating STS client" in str(exception) + + def test_get_region_or_default_success(self): + """Test to verify the _get_region_or_default function with 
success status""" + mock_session = MagicMock() + mock_session.boto_session = MagicMock(region_name="us-east-1") + + region = _get_region_or_default(mock_session) + + assert region == "us-east-1" + + def test_get_region_or_default_exception(self): + """Test to verify the _get_region_or_default function with exception""" + mock_session = MagicMock() + mock_session.boto_session = MagicMock() + mock_session.boto_session.region_name.side_effect = Exception("Error creating boto session") + + with pytest.raises(Exception) as exception: + region = _get_region_or_default(mock_session) + assert region == "us-west-2" + assert "Error creating boto session" in str(exception) + + @patch.object(boto3.Session, "region_name", "us-west-2") + def test_get_default_sagemaker_session(self): + sagemaker_session = _get_default_sagemaker_session() + + assert isinstance(sagemaker_session, sagemaker.Session) is True + assert sagemaker_session.boto_session.region_name == "us-west-2" + + @patch.object(boto3.Session, "region_name", None) + def test_get_default_sagemaker_session_with_no_region(self): + with self.assertRaises(ValueError) as context: + _get_default_sagemaker_session() + + assert "Must setup local AWS configuration with a region supported by SageMaker." in str( + context.exception + ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_valid_region(self, mock_get_region, mock_get_accountId): + """Test to verify telemetry request is sent when region is valid""" + mock_get_accountId.return_value = "testAccountId" + mock_session = MagicMock() + + # Test with valid region + mock_get_region.return_value = "us-east-1" + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + _send_telemetry_request(1, [1, 2], mock_session) + # Assert telemetry request was sent + mock_requests_helper.assert_called_once_with( + "https://sm-pysdk-t-us-east-1.s3.us-east-1.amazonaws.com/telemetry?" + "x-accountId=testAccountId&x-status=1&x-feature=1,2", + 2, + ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_invalid_region(self, mock_get_region, mock_get_accountId): + """Test to verify telemetry request is not sent when region is invalid""" + mock_get_accountId.return_value = "testAccountId" + mock_session = MagicMock() + + # Test with invalid region + mock_get_region.return_value = "invalid-region" + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + _send_telemetry_request(1, [1, 2], mock_session) + # Assert telemetry request was not sent + mock_requests_helper.assert_not_called() diff --git a/tests/unit/sagemaker/test_studio.py b/tests/unit/sagemaker/test_studio.py index 47528e1f36..81302894ab 100644 --- a/tests/unit/sagemaker/test_studio.py +++ b/tests/unit/sagemaker/test_studio.py @@ -12,7 +12,8 @@ # language governing permissions and limitations under the License. # language governing permissions and limitations under the License. 
from __future__ import absolute_import - +import os +from pathlib import Path from sagemaker._studio import ( _append_project_tags, _find_config, @@ -21,6 +22,66 @@ ) +def test_find_config_cross_platform(tmpdir): + """Test _find_config works correctly across different platforms.""" + # Create a completely separate directory for isolated tests + import tempfile + + with tempfile.TemporaryDirectory() as isolated_root: + # Setup test directory structure for positive tests + config = tmpdir.join(".sagemaker-code-config") + config.write('{"sagemakerProjectId": "proj-1234"}') + + # Test 1: Direct parent directory + working_dir = tmpdir.mkdir("sub") + found_path = _find_config(working_dir) + assert found_path == config + + # Test 2: Deeply nested directories + nested_dir = tmpdir.mkdir("deep").mkdir("nested").mkdir("path") + found_path = _find_config(nested_dir) + assert found_path == config + + # Test 3: Start from root directory + import os + + root_dir = os.path.abspath(os.sep) + found_path = _find_config(root_dir) + assert found_path is None + + # Test 4: No config file in path - using truly isolated directory + isolated_path = Path(isolated_root) / "nested" / "path" + isolated_path.mkdir(parents=True) + found_path = _find_config(isolated_path) + assert found_path is None + + +def test_find_config_path_separators(tmpdir): + """Test _find_config handles different path separator styles. + + Tests: + 1. Forward slashes + 2. Backslashes + 3. Mixed separators + """ + # Setup + config = tmpdir.join(".sagemaker-code-config") + config.write('{"sagemakerProjectId": "proj-1234"}') + base_path = str(tmpdir) + + # Always include the OS native path and forward slashes (which are equivalent on all OS) + paths = [os.path.join(base_path, "dir1", "dir2"), "/".join([base_path, "dir1", "dir2"])] + + # Only on Windows add the backslashes and mixed separator test cases. 
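+    # (On POSIX, a backslash is an ordinary filename character, so those variants would create oddly named directories rather than nested paths.)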
+ if os.name == "nt": + paths.extend(["\\".join([base_path, "dir1", "dir2"]), base_path + "/dir1\\dir2"]) + + for path in paths: + os.makedirs(path, exist_ok=True) + found_path = _find_config(path) + assert found_path == config + + def test_find_config(tmpdir): path = tmpdir.join(".sagemaker-code-config") path.write('{"sagemakerProjectId": "proj-1234"}') diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index e1c21cf662..c127d4b5ef 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -402,7 +402,7 @@ def test_pytorchxla_distribution( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, @@ -463,7 +463,7 @@ def test_default_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, @@ -519,7 +519,7 @@ def test_debug_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, @@ -575,7 +575,7 @@ def test_disable_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index e0d172f6e0..b7802f5a6b 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -349,7 +349,7 @@ def test_default_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, @@ -407,7 +407,7 @@ def test_debug_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = 
_create_train_job( huggingface_training_compiler_version, @@ -465,7 +465,7 @@ def test_disable_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( huggingface_training_compiler_version, diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 34a1236a7f..56c6e9966f 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -344,7 +344,7 @@ def test_pytorchxla_distribution( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( pytorch_training_compiler_version, @@ -403,7 +403,7 @@ def test_default_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( pytorch_training_compiler_version, @@ -458,7 +458,7 @@ def test_debug_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( pytorch_training_compiler_version, @@ -513,7 +513,7 @@ def test_disable_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( pytorch_training_compiler_version, diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index ac42bb53ab..54d701ad4e 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -289,7 +289,7 @@ def test_default( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( tensorflow_training_version, @@ -348,7 +348,7 @@ def test_byoc( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( 
tensorflow_training_version, @@ -399,7 +399,7 @@ def test_debug_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( tensorflow_training_version, @@ -450,7 +450,7 @@ def test_disable_compiler_config( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( tensorflow_training_version, diff --git a/tests/unit/sagemaker/workflow/test_notebook_job_step.py b/tests/unit/sagemaker/workflow/test_notebook_job_step.py index 9cc34ee243..aad6767953 100644 --- a/tests/unit/sagemaker/workflow/test_notebook_job_step.py +++ b/tests/unit/sagemaker/workflow/test_notebook_job_step.py @@ -12,11 +12,13 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import os import unittest + from mock import Mock, patch -from sagemaker.workflow.notebook_job_step import NotebookJobStep from sagemaker.workflow.functions import Join +from sagemaker.workflow.notebook_job_step import NotebookJobStep REGION = "us-west-2" PIPELINE_NAME = "test-pipeline-name" @@ -197,11 +199,11 @@ def test_invalid_inputs_required_fields_passed_as_none(self): in str(context.exception) ) self.assertTrue( - "The required input notebook(None) is not a valid file." in str(context.exception) + "The required input notebook (None) is not a valid file." in str(context.exception) ) self.assertTrue( - "The image uri(specified as None) is required and should be hosted in " - "same region of the session(us-west-2)." in str(context.exception) + "The image uri (specified as None) is required and should be hosted in " + "same region of the session (us-west-2)." in str(context.exception) ) self.assertTrue("The kernel name is required." in str(context.exception)) @@ -220,19 +222,19 @@ def test_invalid_paths_to_upload(self): ).arguments self.assertTrue( - "The required input notebook(path/non-existing-file) is not a valid file." + "The required input notebook (path/non-existing-file) is not a valid file." in str(context.exception) ) self.assertTrue( - "The initialization script(path/non-existing-file) is not a valid file." + "The initialization script (non-existing-script) is not a valid file." in str(context.exception) ) self.assertTrue( - "The path(/tmp/non-existing-folder) specified in additional dependencies " + "The path (/tmp/non-existing-folder) specified in additional dependencies " "does not exist." in str(context.exception) ) self.assertTrue( - "The path(path2/non-existing-file) specified in additional dependencies " + "The path (path2/non-existing-file) specified in additional dependencies " "does not exist." in str(context.exception) ) @@ -249,9 +251,9 @@ def test_image_uri_is_not_in_the_expected_region(self): ).arguments self.assertTrue( - "The image uri(specified as 236514542706.dkr.ecr.us-east-9.amazonaws.com/" + "The image uri (specified as 236514542706.dkr.ecr.us-east-9.amazonaws.com/" "sagemaker-data-science) is required and should be hosted in " - "same region of the session(us-west-2)." 
in str(context.exception) + "same region of the session (us-west-2)." in str(context.exception) ) def test_invalid_notebook_job_name(self): @@ -573,3 +575,62 @@ def _create_step_with_required_fields(self): image_uri=IMAGE_URI, kernel_name=KERNEL_NAME, ) + + def test_environment_variables_not_shared(self): + """Test that environment variables are not shared between NotebookJob steps""" + # Setup shared environment variables + shared_env_vars = {"test": "test"} + + # Create two steps with the same environment variables dictionary + step1 = NotebookJobStep( + name="step1", + input_notebook=INPUT_NOTEBOOK, + image_uri=IMAGE_URI, + kernel_name=KERNEL_NAME, + environment_variables=shared_env_vars, + ) + + step2 = NotebookJobStep( + name="step2", + input_notebook=INPUT_NOTEBOOK, + image_uri=IMAGE_URI, + kernel_name=KERNEL_NAME, + environment_variables=shared_env_vars, + ) + + # Get the arguments for both steps + step1_args = step1.arguments + step2_args = step2.arguments + + # Verify that the environment variables are different objects + self.assertIsNot( + step1_args["Environment"], + step2_args["Environment"], + "Environment dictionaries should be different objects", + ) + + # Verify that modifying one step's environment doesn't affect the other + step1_env = step1_args["Environment"] + step2_env = step2_args["Environment"] + + # Both should have the original test value + self.assertEqual(step1_env["test"], "test") + self.assertEqual(step2_env["test"], "test") + + # Modify step1's environment + step1_env["test"] = "modified" + + # Verify step2's environment remains unchanged + self.assertEqual(step2_env["test"], "test") + + # Verify notebook names are correct for each step + self.assertEqual( + step1_env["SM_INPUT_NOTEBOOK_NAME"], + os.path.basename(INPUT_NOTEBOOK), + "Step 1 should have its own notebook name", + ) + self.assertEqual( + step2_env["SM_INPUT_NOTEBOOK_NAME"], + os.path.basename(INPUT_NOTEBOOK), + "Step 2 should have its own notebook name", + ) diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index 14c2d442eb..d83bebd167 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -99,7 +99,7 @@ def test_pipeline_create_and_update_with_config_injection(sagemaker_session_mock RoleArn=pipeline_role_arn, ) pipeline.upsert() - assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=pipeline_role_arn, @@ -130,7 +130,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_ar role_arn=role_arn, parallelism_config=dict(MaxParallelExecutionSteps=10), ) - assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn, @@ -149,7 +149,7 @@ def test_pipeline_create_and_start_with_parallelism_config(sagemaker_session_moc role_arn=role_arn, parallelism_config=dict(MaxParallelExecutionSteps=10), ) - assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn, @@ -168,7 +168,7 @@ def 
test_pipeline_create_and_start_with_parallelism_config(sagemaker_session_moc # Specify ParallelismConfiguration to another value which will be honored in backend pipeline.start(parallelism_config=dict(MaxParallelExecutionSteps=20)) - assert sagemaker_session_mock.sagemaker_client.start_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with( PipelineName="MyPipeline", ParallelismConfiguration={"MaxParallelExecutionSteps": 20}, ) @@ -209,7 +209,7 @@ def test_pipeline_update(sagemaker_session_mock, role_arn): assert not pipeline.steps pipeline.update(role_arn=role_arn) assert len(json.loads(pipeline.definition())["Steps"]) == 0 - assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) @@ -253,7 +253,7 @@ def test_pipeline_update(sagemaker_session_mock, role_arn): pipeline.update(role_arn=role_arn) assert len(json.loads(pipeline.definition())["Steps"]) == 3 - assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) @@ -345,7 +345,11 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar role_arn=role_arn, parallelism_config=dict(MaxParallelExecutionSteps=10), ) - assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + pipeline.update( + role_arn=role_arn, + parallelism_config={"MaxParallelExecutionSteps": 10}, + ) + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn, @@ -387,7 +391,6 @@ def _raise_does_already_exists_client_error(**kwargs): sagemaker_session_mock.sagemaker_client.create_pipeline = Mock( name="create_pipeline", side_effect=_raise_does_already_exists_client_error ) - sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = { "PipelineArn": "pipeline-arn" } @@ -418,15 +421,19 @@ def _raise_does_already_exists_client_error(**kwargs): sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_once_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) - assert sagemaker_session_mock.sagemaker_client.list_tags.called_with( - ResourceArn="mock_pipeline_arn" - ) + sagemaker_session_mock.sagemaker_client.list_tags.assert_called_with(ResourceArn="pipeline-arn") tags.append({"Key": "dummy", "Value": "dummy_tag"}) - assert sagemaker_session_mock.sagemaker_client.add_tags.called_with( - ResourceArn="mock_pipeline_arn", Tags=tags + sagemaker_session_mock.sagemaker_client.add_tags.assert_called_with( + ResourceArn="pipeline-arn", Tags=tags ) + sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = { + "PipelineVersionSummaries": [{"PipelineVersionId": 2}] + } + + assert pipeline.latest_pipeline_version_id == 2 + def test_pipeline_upsert_create_unexpected_failure(sagemaker_session_mock, role_arn): @@ -474,18 +481,11 @@ def _raise_unexpected_client_error(**kwargs): sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called() -def test_pipeline_upsert_resourse_doesnt_exist(sagemaker_session_mock, role_arn): +def test_pipeline_upsert_resource_doesnt_exist(sagemaker_session_mock, role_arn): # case 3: resource does not 
exist sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(name="create_pipeline") - sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = { - "PipelineArn": "pipeline-arn" - } - sagemaker_session_mock.sagemaker_client.list_tags.return_value = { - "Tags": [{"Key": "dummy", "Value": "dummy_tag"}] - } - tags = [ {"Key": "foo", "Value": "abc"}, {"Key": "bar", "Value": "xyz"}, @@ -523,7 +523,7 @@ def test_pipeline_delete(sagemaker_session_mock): sagemaker_session=sagemaker_session_mock, ) pipeline.delete() - assert sagemaker_session_mock.sagemaker_client.delete_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.delete_pipeline.assert_called_with( PipelineName="MyPipeline", ) @@ -536,10 +536,15 @@ def test_pipeline_describe(sagemaker_session_mock): sagemaker_session=sagemaker_session_mock, ) pipeline.describe() - assert sagemaker_session_mock.sagemaker_client.describe_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.describe_pipeline.assert_called_with( PipelineName="MyPipeline", ) + pipeline.describe(pipeline_version_id=5) + sagemaker_session_mock.sagemaker_client.describe_pipeline.assert_called_with( + PipelineName="MyPipeline", PipelineVersionId=5 + ) + def test_pipeline_start(sagemaker_session_mock): sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = { @@ -552,20 +557,25 @@ def test_pipeline_start(sagemaker_session_mock): sagemaker_session=sagemaker_session_mock, ) pipeline.start() - assert sagemaker_session_mock.start_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with( PipelineName="MyPipeline", ) pipeline.start(execution_display_name="pipeline-execution") - assert sagemaker_session_mock.start_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with( PipelineName="MyPipeline", PipelineExecutionDisplayName="pipeline-execution" ) pipeline.start(parameters=dict(alpha="epsilon")) - assert sagemaker_session_mock.start_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with( PipelineName="MyPipeline", PipelineParameters=[{"Name": "alpha", "Value": "epsilon"}] ) + pipeline.start(pipeline_version_id=5) + sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with( + PipelineName="MyPipeline", PipelineVersionId=5 + ) + def test_pipeline_start_selective_execution(sagemaker_session_mock): sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = { @@ -807,6 +817,29 @@ def test_pipeline_list_executions(sagemaker_session_mock): assert executions["NextToken"] == "token" +def test_pipeline_list_versions(sagemaker_session_mock): + sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = { + "PipelineVersionSummaries": [Mock()], + "NextToken": "token", + } + pipeline = Pipeline( + name="MyPipeline", + parameters=[ParameterString("alpha", "beta"), ParameterString("gamma", "delta")], + steps=[], + sagemaker_session=sagemaker_session_mock, + ) + versions = pipeline.list_pipeline_versions() + assert len(versions["PipelineVersionSummaries"]) == 1 + assert versions["NextToken"] == "token" + + sagemaker_session_mock.sagemaker_client.list_pipeline_versions.return_value = { + "PipelineVersionSummaries": [Mock(), Mock()], + } + versions = pipeline.list_pipeline_versions(next_token=versions["NextToken"]) + assert len(versions["PipelineVersionSummaries"]) == 2 + assert "NextToken" 
not in versions + + def test_pipeline_build_parameters_from_execution(sagemaker_session_mock): pipeline = Pipeline( name="MyPipeline", @@ -821,10 +854,8 @@ def test_pipeline_build_parameters_from_execution(sagemaker_session_mock): pipeline_execution_arn=reference_execution_arn, parameter_value_overrides=parameter_value_overrides, ) - assert ( - sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with( - PipelineExecutionArn=reference_execution_arn - ) + sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_called_with( + PipelineExecutionArn=reference_execution_arn ) assert len(parameters) == 1 assert parameters["TestParameterName"] == "NewParameterValue" @@ -850,10 +881,8 @@ def test_pipeline_build_parameters_from_execution_with_invalid_overrides(sagemak + f"are not present in the pipeline execution: {reference_execution_arn}" in str(error) ) - assert ( - sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with( - PipelineExecutionArn=reference_execution_arn - ) + sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_called_with( + PipelineExecutionArn=reference_execution_arn ) @@ -908,24 +937,23 @@ def test_pipeline_execution_basics(sagemaker_session_mock): ) execution = pipeline.start() execution.stop() - assert sagemaker_session_mock.sagemaker_client.stop_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.stop_pipeline_execution.assert_called_with( PipelineExecutionArn="my:arn" ) execution.describe() - assert sagemaker_session_mock.sagemaker_client.describe_pipeline_execution.called_with( + sagemaker_session_mock.sagemaker_client.describe_pipeline_execution.assert_called_with( PipelineExecutionArn="my:arn" ) steps = execution.list_steps() - assert sagemaker_session_mock.sagemaker_client.describe_pipeline_execution_steps.called_with( + sagemaker_session_mock.sagemaker_client.list_pipeline_execution_steps.assert_called_with( PipelineExecutionArn="my:arn" ) assert len(steps) == 1 list_parameters_response = execution.list_parameters() - assert ( - sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with( - PipelineExecutionArn="my:arn" - ) + sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_called_with( + PipelineExecutionArn="my:arn" ) + parameter_list = list_parameters_response["PipelineParameters"] assert len(parameter_list) == 1 assert parameter_list[0]["Name"] == "TestParameterName" diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index e0b9aae824..ddc76a05f7 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -55,7 +55,7 @@ MODEL_NAME = "gisele" MODEL_REPACKING_IMAGE_URI = ( - "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3" + "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:1.2-1-cpu-py3" ) @@ -1219,8 +1219,7 @@ def test_estimator_transformer_with_model_repack_with_estimator(estimator, sourc assert arguments == { "AlgorithmSpecification": { "TrainingInputMode": "File", - "TrainingImage": "246618743249.dkr.ecr.us-west-2.amazonaws.com/" - + "sagemaker-scikit-learn:0.23-1-cpu-py3", + "TrainingImage": MODEL_REPACKING_IMAGE_URI, }, "ProfilerConfig": {"DisableProfiler": True}, "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
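[Review note] The test_pipeline.py hunks above replace the pattern assert mock.method.called_with(...) with mock.method.assert_called_with(...). The old form can never fail: called_with is not part of the Mock assertion API, so attribute access silently creates a child Mock, and calling that child just returns another Mock, which is always truthy. A minimal standalone sketch of the pitfall, using only unittest.mock and a hypothetical update_pipeline call rather than any SageMaker client:

    from unittest.mock import Mock

    client = Mock()
    client.update_pipeline(PipelineName="MyPipeline")

    # Bug pattern: `called_with` is an auto-created child Mock; calling it
    # returns yet another Mock, which is truthy, so this passes even though
    # the arguments are wrong.
    assert client.update_pipeline.called_with(PipelineName="WrongName")

    # Fixed pattern: `assert_called_with` compares against the recorded call
    # and raises AssertionError on any mismatch.
    client.update_pipeline.assert_called_with(PipelineName="MyPipeline")

This is also why test_pipeline_update_with_parallelism_config gains an explicit pipeline.update(...) call above: the old assertion was a no-op, so nothing had actually exercised the path it claimed to check.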
diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index b3d667a1c3..84906ce620 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -671,7 +671,7 @@ def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, sc mock_normalize_args.return_value = [step.inputs, step.outputs] step.to_request() mock_normalize_args.assert_called_with( - job_name="MyProcessingStep-3e89f0c7e101c356cbedf27d9d27e9db", + job_name=None, arguments=step.job_arguments, inputs=step.inputs, outputs=step.outputs, diff --git a/tests/unit/sagemaker/workflow/test_transform_step.py b/tests/unit/sagemaker/workflow/test_transform_step.py index d22965dae8..19471228d6 100644 --- a/tests/unit/sagemaker/workflow/test_transform_step.py +++ b/tests/unit/sagemaker/workflow/test_transform_step.py @@ -70,6 +70,7 @@ custom_step.properties.OutputDataConfig.S3OutputPath, ], ) +@pytest.mark.flaky(reruns=5, reruns_delay=1) def test_transform_step_with_transformer(model_name, data, output_path, pipeline_session): transformer = Transformer( model_name=model_name, diff --git a/tests/unit/sagemaker/workflow/test_utilities.py b/tests/unit/sagemaker/workflow/test_utilities.py index e65d3ea933..b284ced91e 100644 --- a/tests/unit/sagemaker/workflow/test_utilities.py +++ b/tests/unit/sagemaker/workflow/test_utilities.py @@ -31,14 +31,14 @@ def test_hash_file(): with tempfile.NamedTemporaryFile() as tmp: tmp.write("hashme".encode()) hash = hash_file(tmp.name) - assert hash == "d41d8cd98f00b204e9800998ecf8427e" + assert hash == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" def test_hash_file_uri(): with tempfile.NamedTemporaryFile() as tmp: tmp.write("hashme".encode()) hash = hash_file(f"file:///{tmp.name}") - assert hash == "d41d8cd98f00b204e9800998ecf8427e" + assert hash == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" def test_hash_files_or_dirs_with_file(): diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index 48b1d762c3..b18ed71f9b 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -80,10 +80,11 @@ def test_repack_model_step(estimator): assert hyperparameters["inference_script"] == '"dummy_script.py"' assert hyperparameters["model_archive"] == '"s3://my-bucket/model.tar.gz"' assert hyperparameters["sagemaker_program"] == f'"{REPACK_SCRIPT_LAUNCHER}"' - assert ( - hyperparameters["sagemaker_submit_directory"] - == '"s3://my-bucket/MyRepackModelStep-b5ea77f701b47a8d075605497462ccc2/source/sourcedir.tar.gz"' - ) + + # ex: "s3://my-bucket/sagemaker-scikit-learn-2025-04-07-20-39-38-854/source/sourcedir.tar.gz" + sagemaker_submit_directory = hyperparameters["sagemaker_submit_directory"] + assert sagemaker_submit_directory.startswith('"s3://my-bucket/sagemaker-scikit-learn-') + assert sagemaker_submit_directory.endswith('/source/sourcedir.tar.gz"') del request_dict["Arguments"]["HyperParameters"] del request_dict["Arguments"]["AlgorithmSpecification"]["TrainingImage"]
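[Review note] Context for the test_chainer.py hunks that follow: distutils, and with it distutils.util.strtobool, was removed from the standard library in Python 3.12 (PEP 632), which the py312 targets added elsewhere in this PR require. The test therefore inlines the truthy half of strtobool's lookup table. If a two-way conversion were ever needed again, a replacement is a few lines; the helper below is a hypothetical sketch, not something the SDK ships, and it returns a bool where the original returned 1 or 0:

    def strtobool(value: str) -> bool:
        # Same accepted spellings as the removed distutils.util.strtobool.
        value = value.lower()
        if value in ("y", "yes", "t", "true", "on", "1"):
            return True
        if value in ("n", "no", "f", "false", "off", "0"):
            return False
        raise ValueError(f"invalid truth value {value!r}")

    # The membership check used in the test covers only the truthy branch:
    assert strtobool("True") is True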
diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 8ad2ae0bab..0ac8cb0888 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -15,7 +15,6 @@ import logging import json import os -from distutils.util import strtobool import pytest from mock import MagicMock, Mock, ANY @@ -174,7 +173,15 @@ def test_additional_hyperparameters(sagemaker_session, chainer_version, chainer_ framework_version=chainer_version, py_version=chainer_py_version, ) - assert bool(strtobool(chainer.hyperparameters()["sagemaker_use_mpi"])) + + assert chainer.hyperparameters()["sagemaker_use_mpi"].lower() in ( + "y", + "yes", + "t", + "true", + "on", + "1", + ) assert int(chainer.hyperparameters()["sagemaker_num_processes"]) == 4 assert int(chainer.hyperparameters()["sagemaker_process_slots_per_host"]) == 10 assert ( @@ -354,7 +361,7 @@ def test_chainer(strftime, time, sagemaker_session, chainer_version, chainer_py_ sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job(chainer_version, chainer_py_version) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs diff --git a/tests/unit/test_common.py b/tests/unit/test_common.py index 8fe7383fe4..9fe49ad448 100644 --- a/tests/unit/test_common.py +++ b/tests/unit/test_common.py @@ -16,12 +16,12 @@ import tempfile import pytest import itertools +from sagemaker.deserializers import RecordDeserializer +from sagemaker.serializers import RecordSerializer from scipy.sparse import coo_matrix from sagemaker.amazon.common import ( - RecordDeserializer, write_numpy_to_dense_tensor, read_recordio, - RecordSerializer, write_spmatrix_to_sparse_tensor, ) from sagemaker.amazon.record_pb2 import Record diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 9f0d68f01d..dca1d3dc85 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -13,6 +13,8 @@ from __future__ import absolute_import import datetime +from unittest.mock import Mock + import pytest from botocore.exceptions import ClientError from mock import MagicMock @@ -37,13 +39,32 @@ def sagemaker_session(): return sagemaker_session +@pytest.fixture() +def sagemaker_session_with_bucket_name_and_prefix(): + boto_mock = MagicMock(name="boto_session", region_name=REGION) + boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID} + sagemaker_session = sagemaker.Session( + boto_session=boto_mock, + default_bucket="XXXXXXXXXXXXX", + default_bucket_prefix="sample-prefix", + ) + sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + return sagemaker_session + + def test_default_bucket_s3_create_call(sagemaker_session): error = ClientError( error_response={"Error": {"Code": "404", "Message": "Not Found"}}, operation_name="foo", ) - sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error - bucket_name = sagemaker_session.default_bucket() + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = Mock( + side_effect=error + ) + + try: + bucket_name = sagemaker_session.default_bucket() + except ClientError: + pass create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls _1, _2, create_kwargs = create_calls[0] @@ -53,7 +74,6 @@ def test_default_bucket_s3_create_call(sagemaker_session): "CreateBucketConfiguration": {"LocationConstraint": "us-west-2"}, "Bucket": bucket_name, } - assert sagemaker_session._default_bucket == bucket_name def test_default_bucket_s3_needs_access(sagemaker_session, caplog): @@ -89,6 +109,30 @@ def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime assert
sagemaker_session._default_bucket is None +def test_default_bucket_with_prefix_s3_needs_bucket_owner_access( + sagemaker_session_with_bucket_name_and_prefix, datetime_obj, caplog +): + with pytest.raises(ClientError): + error = ClientError( + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, + operation_name="foo", + ) + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource( + "s3" + ).meta.client.list_objects_v2.side_effect = error + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").Bucket( + name=DEFAULT_BUCKET_NAME + ).creation_date = None + sagemaker_session_with_bucket_name_and_prefix.default_bucket() + + error_message = "Please try again after adding appropriate access." + assert error_message in caplog.text + assert sagemaker_session_with_bucket_name_and_prefix._default_bucket is None + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource( + "s3" + ).meta.client.list_objects_v2.assert_called_once() + + def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog): sagemaker_session._default_bucket_name_override = "custom-bucket-override" error = ClientError( diff --git a/tests/unit/test_djl_inference.py b/tests/unit/test_djl_inference.py index cc8a99cf1c..abdd86fbfc 100644 --- a/tests/unit/test_djl_inference.py +++ b/tests/unit/test_djl_inference.py @@ -12,42 +12,25 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import logging - -import json -from json import JSONDecodeError - import pytest -from mock import Mock, MagicMock -from mock import patch, mock_open +from mock import Mock from sagemaker.djl_inference import ( - defaults, DJLModel, - DJLPredictor, - HuggingFaceAccelerateModel, - DeepSpeedModel, ) -from sagemaker.djl_inference.model import DJLServingEngineEntryPointDefaults -from sagemaker.s3_utils import s3_path_join from sagemaker.session_settings import SessionSettings -from tests.unit import ( - _test_default_bucket_and_prefix_combinations, - DEFAULT_S3_BUCKET_NAME, - DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, -) +from sagemaker import image_uris VALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model" -INVALID_UNCOMPRESSED_MODEL_DATA = "s3://mybucket/model.tar.gz" +VALID_COMPRESSED_MODEL_DATA = "s3://mybucket/model.tar.gz" HF_MODEL_ID = "hf_hub_model_id" -ENTRY_POINT = "entrypoint.py" -SOURCE_DIR = "source_dir/" -ENV = {"ENV_VAR": "env_value"} ROLE = "dummy_role" REGION = "us-west-2" -BUCKET = "mybucket" -IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazon.com/djl-inference:0.24.0-deepspeed0.10.0-cu118" -GPU_INSTANCE = "ml.g5.12xlarge" +VERSION = "latest" + +LMI_IMAGE_URI = image_uris.retrieve(framework="djl-lmi", version=VERSION, region=REGION) +TRT_IMAGE_URI = image_uris.retrieve(framework="djl-tensorrtllm", version=VERSION, region=REGION) +TNX_IMAGE_URI = image_uris.retrieve(framework="djl-neuronx", version=VERSION, region=REGION) @pytest.fixture() @@ -66,756 +49,134 @@ def sagemaker_session(): endpoint_from_production_variants=Mock(name="endpoint_from_production_variants"), default_bucket_prefix=None, ) - session.default_bucket = Mock(name="default_bucket", return_value=BUCKET) + session.default_bucket = Mock(name="default_bucket", return_value="bucket") # For tests which doesn't verify config file injection, operate with empty config session.sagemaker_config = {} return session -def test_create_model_invalid_s3_uri(): - with pytest.raises(ValueError) as invalid_s3_data: - _ = DJLModel( - INVALID_UNCOMPRESSED_MODEL_DATA, 
- ROLE, - ) - assert str(invalid_s3_data.value).startswith( - "DJLModel does not support model artifacts in tar.gz" - ) - - -@patch("urllib.request.urlopen") -def test_create_model_valid_hf_hub_model_id( - mock_urlopen, - sagemaker_session, -): - model_config = { - "model_type": "opt", - "num_attention_heads": 4, - } - - cm = MagicMock() - cm.getcode.return_value = 200 - cm.read.return_value = json.dumps(model_config).encode("utf-8") - cm.__enter__.return_value = cm - mock_urlopen.return_value = cm +def test_create_djl_model_only_model_id(sagemaker_session): model = DJLModel( - HF_MODEL_ID, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - assert model.engine == DJLServingEngineEntryPointDefaults.DEEPSPEED - expected_url = f"https://huggingface.co/{HF_MODEL_ID}/raw/main/config.json" - mock_urlopen.assert_any_call(expected_url) - - serving_properties = model.generate_serving_properties() - assert serving_properties["option.model_id"] == HF_MODEL_ID - - -@patch("json.load") -@patch("urllib.request.urlopen") -def test_create_model_invalid_hf_hub_model_id( - mock_urlopen, - json_load, - sagemaker_session, -): - expected_url = f"https://huggingface.co/{HF_MODEL_ID}/raw/main/config.json" - with pytest.raises(ValueError) as invalid_model_id: - cm = MagicMock() - cm.__enter__.return_value = cm - mock_urlopen.return_value = cm - json_load.side_effect = JSONDecodeError("", "", 0) - _ = DJLModel( - HF_MODEL_ID, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - mock_urlopen.assert_any_call(expected_url) - assert str(invalid_model_id.value).startswith( - "Did not find a config.json or model_index.json file in huggingface hub" - ) - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_create_model_automatic_engine_selection(mock_s3_list, mock_read_file, sagemaker_session): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - hf_model_config = { - "model_type": "t5", - "num_attention_heads": 4, - } - mock_read_file.return_value = json.dumps(hf_model_config) - hf_model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, + model_id=VALID_UNCOMPRESSED_MODEL_DATA, sagemaker_session=sagemaker_session, - number_of_partitions=4, + role=ROLE, ) - assert hf_model.engine == DJLServingEngineEntryPointDefaults.FASTER_TRANSFORMER - - hf_model_config = { - "model_type": "gpt2", - "num_attention_heads": 25, - } - mock_read_file.return_value = json.dumps(hf_model_config) - hf_model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - assert hf_model.engine == DJLServingEngineEntryPointDefaults.HUGGINGFACE_ACCELERATE - - for model_type in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES: - ds_model_config = { - "model_type": model_type, - "num_attention_heads": 12, - } - mock_read_file.return_value = json.dumps(ds_model_config) - ds_model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=2, - ) - mock_s3_list.assert_any_call( - VALID_UNCOMPRESSED_MODEL_DATA, sagemaker_session=sagemaker_session - ) - if model_type == defaults.STABLE_DIFFUSION_MODEL_TYPE: - assert ds_model.engine == DJLServingEngineEntryPointDefaults.STABLE_DIFFUSION - else: - assert ds_model.engine == DJLServingEngineEntryPointDefaults.DEEPSPEED + assert model.engine == "Python" + assert model.image_uri == LMI_IMAGE_URI + assert model.env == {"HF_MODEL_ID": VALID_UNCOMPRESSED_MODEL_DATA, 
"OPTION_ENGINE": "Python"} -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_create_deepspeed_model(mock_s3_list, mock_read_file, sagemaker_session): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - ds_model_config = { - "model_type": "opt", - "n_head": 12, - } - mock_read_file.return_value = json.dumps(ds_model_config) - ds_model = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - tensor_parallel_degree=4, - ) - assert ds_model.engine == DJLServingEngineEntryPointDefaults.DEEPSPEED - - ds_model_config = { - "model_type": "opt", - "n_head": 25, - } - mock_read_file.return_value = json.dumps(ds_model_config) - with pytest.raises(ValueError) as invalid_partitions: - _ = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - tensor_parallel_degree=4, - ) - assert str(invalid_partitions.value).startswith("The number of attention heads is not evenly") - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_create_huggingface_model(mock_s3_list, mock_read_file, sagemaker_session): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - hf_model_config = { - "model_type": "opt", - "n_head": 12, - } - mock_read_file.return_value = json.dumps(hf_model_config) - hf_model = HuggingFaceAccelerateModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - assert hf_model.engine == DJLServingEngineEntryPointDefaults.HUGGINGFACE_ACCELERATE - - hf_model_config = { - "model_type": "t5", - "n_head": 13, - } - mock_read_file.return_value = json.dumps(hf_model_config) - hf_model = HuggingFaceAccelerateModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - assert hf_model.engine == DJLServingEngineEntryPointDefaults.HUGGINGFACE_ACCELERATE - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_model_unsupported_methods(mock_s3_list, mock_read_file, sagemaker_session): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "opt", - "n_head": 12, - } - mock_read_file.return_value = json.dumps(model_config) +def test_create_djl_model_only_model_data(sagemaker_session): model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, + model_data={ + "S3DataSource": { + "S3Uri": VALID_COMPRESSED_MODEL_DATA, + "S3DataType": "S3Object", + "CompressionType": "Gzip", + } + }, sagemaker_session=sagemaker_session, + role=ROLE, ) + assert model.engine == "Python" + assert model.image_uri == LMI_IMAGE_URI + assert model.env == {"OPTION_ENGINE": "Python"} - with pytest.raises(NotImplementedError) as invalid_method: - model.package_for_edge() - assert str(invalid_method.value).startswith("DJLModels do not support Sagemaker Edge") - - with pytest.raises(NotImplementedError) as invalid_method: - model.compile() - assert str(invalid_method.value).startswith( - "DJLModels do not currently support compilation with SageMaker Neo" - ) - - with pytest.raises(NotImplementedError) as invalid_method: - model.transformer() - assert str(invalid_method.value).startswith( - "DJLModels do not currently support Batch Transform inference jobs" - ) - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def 
test_deploy_base_model_invalid_instance(mock_s3_list, mock_read_file, sagemaker_session): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "gpt-neox", - "n_head": 25, - } - mock_read_file.return_value = json.dumps(model_config) +def test_create_djl_model_with_task(sagemaker_session): model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, + model_id=VALID_UNCOMPRESSED_MODEL_DATA, sagemaker_session=sagemaker_session, - number_of_partitions=4, - ) - - with pytest.raises(ValueError) as invalid_instance: - _ = model.deploy("ml.m5.12xlarge") - assert str(invalid_instance.value).startswith("Invalid instance type. DJLModels only support") - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_generate_deepspeed_serving_properties_invalid_configurations( - mock_s3_list, mock_read_file, sagemaker_session -): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "bert", - "n_head": 4, - } - mock_read_file.return_value = json.dumps(model_config) - model = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - tensor_parallel_degree=4, - enable_cuda_graph=True, - ) - with pytest.raises(ValueError) as invalid_config: - _ = model.generate_serving_properties() - assert str(invalid_config.value).startswith("enable_cuda_graph is not supported") - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_generate_huggingface_serving_properties_invalid_configurations( - mock_s3_list, mock_read_file, sagemaker_session -): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "t5", - "n_head": 4, - } - mock_read_file.return_value = json.dumps(model_config) - model = HuggingFaceAccelerateModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - dtype="fp16", - load_in_8bit=True, - ) - with pytest.raises(ValueError) as invalid_config: - _ = model.generate_serving_properties() - assert str(invalid_config.value).startswith("Set dtype='int8' to use load_in_8bit") - - model = HuggingFaceAccelerateModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=2, - device_id=1, - ) - with pytest.raises(ValueError) as invalid_config: - _ = model.generate_serving_properties() - assert str(invalid_config.value).startswith( - "device_id cannot be set when number_of_partitions is > 1" - ) - - -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_generate_serving_properties_with_valid_configurations( - mock_s3_list, mock_read_file, sagemaker_session -): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "gpt-neox", - "n_head": 25, - } - mock_read_file.return_value = json.dumps(model_config) - model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - min_workers=1, - max_workers=3, - job_queue_size=4, - dtype="fp16", - parallel_loading=True, - model_loading_timeout=120, - prediction_timeout=4, - source_dir=SOURCE_DIR, - entry_point=ENTRY_POINT, - task="text-classification", - ) - serving_properties = model.generate_serving_properties() - expected_dict = { - "engine": "Python", - "option.entryPoint": ENTRY_POINT, - "option.model_id": 
VALID_UNCOMPRESSED_MODEL_DATA, - "option.tensor_parallel_degree": 4, - "option.task": "text-classification", - "option.dtype": "fp16", - "minWorkers": 1, - "maxWorkers": 3, - "job_queue_size": 4, - "option.parallel_loading": True, - "option.model_loading_timeout": 120, - "option.prediction_timeout": 4, - } - assert serving_properties == expected_dict - serving_properties.clear() - expected_dict.clear() - - model_config = { - "model_type": "opt", - "n_head": 4, - } - mock_read_file.return_value = json.dumps(model_config) - model = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - tensor_parallel_degree=1, + role=ROLE, task="text-generation", - dtype="bf16", - max_tokens=2048, - low_cpu_mem_usage=True, - enable_cuda_graph=True, ) - serving_properties = model.generate_serving_properties() - expected_dict = { - "engine": "DeepSpeed", - "option.entryPoint": "djl_python.deepspeed", - "option.model_id": VALID_UNCOMPRESSED_MODEL_DATA, - "option.tensor_parallel_degree": 1, - "option.task": "text-generation", - "option.dtype": "bf16", - "option.max_tokens": 2048, - "option.enable_cuda_graph": True, - "option.low_cpu_mem_usage": True, - "option.triangular_masking": True, - "option.return_tuple": True, + assert model.engine == "Python" + assert model.image_uri == LMI_IMAGE_URI + assert model.env == { + "HF_MODEL_ID": VALID_UNCOMPRESSED_MODEL_DATA, + "OPTION_ENGINE": "Python", + "HF_TASK": "text-generation", } - assert serving_properties == expected_dict - serving_properties.clear() - expected_dict.clear() - model = HuggingFaceAccelerateModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=1, - device_id=4, - device_map="balanced", - dtype="fp32", - low_cpu_mem_usage=False, - ) - serving_properties = model.generate_serving_properties() - expected_dict = { - "engine": "Python", - "option.entryPoint": "djl_python.huggingface", - "option.model_id": VALID_UNCOMPRESSED_MODEL_DATA, - "option.tensor_parallel_degree": 1, - "option.dtype": "fp32", - "option.device_id": 4, - "option.device_map": "balanced", - } - assert serving_properties == expected_dict - - -@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI) -@patch("shutil.rmtree") -@patch("sagemaker.utils.base_name_from_image") -@patch("tempfile.mkdtemp") -@patch("sagemaker.container_def") -@patch("sagemaker.utils._tmpdir") -@patch("sagemaker.utils._create_or_update_code_dir") -@patch("sagemaker.fw_utils.tar_and_upload_dir") -@patch("os.mkdir") -@patch("os.path.exists") -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -def test_deploy_model_no_local_code( - mock_s3_list, - mock_read_file, - mock_path_exists, - mock_mkdir, - mock_tar_upload, - mock_create_code_dir, - mock_tmpdir, - mock_container_def, - mock_mktmp, - mock_name_from_base, - mock_shutil_rmtree, - mock_imguri_retrieve, - sagemaker_session, -): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "bloom", - "n_heads": 120, - } - mock_read_file.return_value = json.dumps(model_config) model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, + model_id=HF_MODEL_ID, sagemaker_session=sagemaker_session, - number_of_partitions=4, - dtype="fp16", - container_log_level=logging.DEBUG, - env=ENV, + role=ROLE, + task="text-embedding", ) - - assert model.image_uri is None - - mock_path_exists.side_effect = [True, False, True] - mock_mktmp.return_value = "/tmp/dir" - 
mock_tar_upload.return_value = Mock(s3_prefix="s3prefix") - expected_env = {"ENV_VAR": "env_value", "SERVING_OPTS": '"-Dai.djl.logging.level=debug"'} - with patch("builtins.open", mock_open()) as fake_serving_properties: - predictor = model.deploy(GPU_INSTANCE) - - assert isinstance(predictor, DJLPredictor) - mock_mktmp.assert_called_once_with(prefix="tmp", suffix="", dir=None) - mock_mkdir.assert_called() - assert fake_serving_properties.call_count == 2 - fake_serving_properties.assert_any_call("/tmp/dir/code/serving.properties", "w+") - fake_serving_properties.assert_any_call("/tmp/dir/code/serving.properties", "r") - model.sagemaker_session.create_model.assert_called_once() - mock_container_def.assert_called_once_with( - IMAGE_URI, model_data_url="s3prefix", env=expected_env - ) - - -@patch("sagemaker.image_uris.retrieve", return_value=IMAGE_URI) -@patch("shutil.rmtree") -@patch("sagemaker.utils.base_name_from_image") -@patch("tempfile.mkdtemp") -@patch("sagemaker.container_def") -@patch("sagemaker.utils._tmpdir") -@patch("sagemaker.utils._create_or_update_code_dir") -@patch("os.mkdir") -@patch("os.path.exists") -@patch("sagemaker.s3.S3Downloader.read_file") -@patch("sagemaker.s3.S3Downloader.list") -@patch("sagemaker.s3.S3Uploader.upload") -@patch("sagemaker.estimator.Estimator.fit") -@patch("sagemaker.fw_utils.model_code_key_prefix") -@patch("os.path.isfile") -@patch("boto3.client") -def test_partition( - mock_client, - mock_is_file, - mock_model_key_prefix, - mock_estimator_fit, - mock_upload, - mock_s3_list, - mock_read_file, - mock_path_exists, - mock_mkdir, - mock_create_code_dir, - mock_tmpdir, - mock_container_def, - mock_mktmp, - mock_name_from_base, - mock_shutil_rmtree, - mock_imguri_retrieve, - sagemaker_session, -): - mock_s3_list.return_value = [VALID_UNCOMPRESSED_MODEL_DATA + "/config.json"] - model_config = { - "model_type": "bloom", - "n_heads": 120, + assert model.engine == "OnnxRuntime" + assert model.image_uri == LMI_IMAGE_URI + assert model.env == { + "HF_MODEL_ID": HF_MODEL_ID, + "OPTION_ENGINE": "OnnxRuntime", + "HF_TASK": "text-embedding", } - mock_read_file.return_value = json.dumps(model_config) - model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sagemaker_session, - number_of_partitions=4, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - ) - - assert model.image_uri is None - mock_is_file.return_value = False - mock_path_exists.side_effect = [True, False, True] - mock_mktmp.return_value = "/tmp/dir" - expected_env = {"ENV_VAR": "env_value", "SERVING_OPTS": '"-Dai.djl.logging.level=debug"'} - mock_upload.return_value = "s3prefix" - - s3_output_uri = f"s3://{BUCKET}/partitions/" - mock_model_key_prefix.return_value = "s3prefix" - with patch("builtins.open", mock_open()) as fake_serving_properties: - model.partition(GPU_INSTANCE, s3_output_uri) - - mock_mktmp.assert_called_once_with(prefix="tmp", suffix="", dir=None) - mock_mkdir.assert_called() - assert fake_serving_properties.call_count == 2 - fake_serving_properties.assert_any_call("/tmp/dir/code/serving.properties", "w+") - fake_serving_properties.assert_any_call("/tmp/dir/code/serving.properties", "r") - mock_container_def.assert_called_once_with( - IMAGE_URI, model_data_url="s3prefix", env=expected_env - ) - - assert model.model_id == f"{s3_output_uri}aot-partitioned-checkpoints" - -@patch("sagemaker.djl_inference.model.fw_utils.model_code_key_prefix") -@patch("sagemaker.djl_inference.model._get_model_config_properties_from_s3") 
-@patch("sagemaker.djl_inference.model.fw_utils.tar_and_upload_dir") -def test__upload_model_to_s3__with_upload_as_tar__default_bucket_and_prefix_combinations( - tar_and_upload_dir, - _get_model_config_properties_from_s3, - model_code_key_prefix, -): - # Skip appending of timestamps that this normally does - model_code_key_prefix.side_effect = lambda a, b, c: s3_path_join(a, b, c) - def with_user_input(sess): +def test_create_djl_model_with_provided_image(sagemaker_session): + for img_uri in [LMI_IMAGE_URI, TRT_IMAGE_URI, TNX_IMAGE_URI]: model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - number_of_partitions=4, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - code_location="s3://test-bucket/test-prefix/test-prefix-2", - image_uri="image_uri", - ) - model._upload_model_to_s3(upload_as_tar=True) - args = tar_and_upload_dir.call_args.args - return "s3://%s/%s" % (args[1], args[2]) - - def without_user_input(sess): - model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - number_of_partitions=4, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - image_uri="image_uri", + model_id=VALID_UNCOMPRESSED_MODEL_DATA, + sagemaker_session=sagemaker_session, + role=ROLE, + image_uri=img_uri, ) - model._upload_model_to_s3(upload_as_tar=True) - args = tar_and_upload_dir.call_args.args - return "s3://%s/%s" % (args[1], args[2]) - - actual, expected = _test_default_bucket_and_prefix_combinations( - function_with_user_input=with_user_input, - function_without_user_input=without_user_input, - expected__without_user_input__with_default_bucket_and_default_prefix=( - f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/image_uri" - ), - expected__without_user_input__with_default_bucket_only=( - f"s3://{DEFAULT_S3_BUCKET_NAME}/image_uri" - ), - expected__with_user_input__with_default_bucket_and_prefix=( - "s3://test-bucket/test-prefix/test-prefix-2/image_uri" - ), - expected__with_user_input__with_default_bucket_only=( - "s3://test-bucket/test-prefix/test-prefix-2/image_uri" - ), - ) - assert actual == expected - - -@patch("sagemaker.djl_inference.model.fw_utils.model_code_key_prefix") -@patch("sagemaker.djl_inference.model._get_model_config_properties_from_s3") -@patch("sagemaker.djl_inference.model.S3Uploader.upload") -def test__upload_model_to_s3__without_upload_as_tar__default_bucket_and_prefix_combinations( - upload, - _get_model_config_properties_from_s3, - model_code_key_prefix, -): - """This test is similar to test__upload_model_to_s3__with_upload_as_tar__default_bucket_and_prefix_combinations - - except upload_as_tar is False and S3Uploader.upload is checked - """ - - # Skip appending of timestamps that this normally does - model_code_key_prefix.side_effect = lambda a, b, c: s3_path_join(a, b, c) + assert model.engine == "Python" + assert model.image_uri == img_uri + assert model.env == { + "HF_MODEL_ID": VALID_UNCOMPRESSED_MODEL_DATA, + "OPTION_ENGINE": "Python", + } - def with_user_input(sess): + for framework in ["djl-lmi", "djl-tensorrtllm", "djl-neuronx"]: model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - number_of_partitions=4, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - code_location="s3://test-bucket/test-prefix/test-prefix-2", - image_uri="image_uri", + model_id=VALID_UNCOMPRESSED_MODEL_DATA, + sagemaker_session=sagemaker_session, + role=ROLE, + djl_framework=framework, ) - 
model._upload_model_to_s3(upload_as_tar=False) - args = upload.call_args.args - return args[1] - - def without_user_input(sess): - model = DJLModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - number_of_partitions=4, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - image_uri="image_uri", + assert model.engine == "Python" + assert model.image_uri == image_uris.retrieve( + framework=framework, version=VERSION, region=REGION ) - model._upload_model_to_s3(upload_as_tar=False) - args = upload.call_args.args - return args[1] - - actual, expected = _test_default_bucket_and_prefix_combinations( - function_with_user_input=with_user_input, - function_without_user_input=without_user_input, - expected__without_user_input__with_default_bucket_and_default_prefix=( - f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/image_uri/aot-model" - ), - expected__without_user_input__with_default_bucket_only=( - f"s3://{DEFAULT_S3_BUCKET_NAME}/image_uri/aot-model" - ), - expected__with_user_input__with_default_bucket_and_prefix=( - "s3://test-bucket/test-prefix/test-prefix-2/image_uri/aot-model" - ), - expected__with_user_input__with_default_bucket_only=( - "s3://test-bucket/test-prefix/test-prefix-2/image_uri/aot-model" - ), - ) - assert actual == expected - - -@pytest.mark.parametrize( - ( - "code_location," - "expected__without_user_input__with_default_bucket_and_default_prefix, " - "expected__without_user_input__with_default_bucket_only, " - "expected__with_user_input__with_default_bucket_and_prefix, " - "expected__with_user_input__with_default_bucket_only" - ), - [ - ( - "s3://code-test-bucket/code-test-prefix/code-test-prefix-2", - "s3://code-test-bucket/code-test-prefix/code-test-prefix-2/image_uri", - "s3://code-test-bucket/code-test-prefix/code-test-prefix-2/image_uri", - "s3://test-bucket/test-prefix/test-prefix-2", - "s3://test-bucket/test-prefix/test-prefix-2", - ), - ( - None, - f"s3://{DEFAULT_S3_BUCKET_NAME}/{DEFAULT_S3_OBJECT_KEY_PREFIX_NAME}/image_uri", - f"s3://{DEFAULT_S3_BUCKET_NAME}/image_uri", - "s3://test-bucket/test-prefix/test-prefix-2", - "s3://test-bucket/test-prefix/test-prefix-2", - ), - ], -) -@patch("sagemaker.djl_inference.model.fw_utils.model_code_key_prefix") -@patch("sagemaker.djl_inference.model._get_model_config_properties_from_s3") -@patch("sagemaker.djl_inference.model.fw_utils.tar_and_upload_dir") -@patch("sagemaker.djl_inference.model._create_estimator") -def test_partition_default_bucket_and_prefix_combinations( - _create_estimator, - tar_and_upload_dir, - _get_model_config_properties_from_s3, - model_code_key_prefix, - code_location, - expected__without_user_input__with_default_bucket_and_default_prefix, - expected__without_user_input__with_default_bucket_only, - expected__with_user_input__with_default_bucket_and_prefix, - expected__with_user_input__with_default_bucket_only, -): - # Skip appending of timestamps that this normally does - model_code_key_prefix.side_effect = lambda a, b, c: s3_path_join(a, b, c) - - def with_user_input(sess): - model = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - code_location=code_location, - image_uri="image_uri", - ) - model.partition(GPU_INSTANCE, s3_output_uri="s3://test-bucket/test-prefix/test-prefix-2") - kwargs = _create_estimator.call_args.kwargs - return kwargs["s3_output_uri"] + assert model.env == { + "HF_MODEL_ID": VALID_UNCOMPRESSED_MODEL_DATA, + 
"OPTION_ENGINE": "Python", + } - def without_user_input(sess): - model = DeepSpeedModel( - VALID_UNCOMPRESSED_MODEL_DATA, - ROLE, - sagemaker_session=sess, - data_type="fp16", - container_log_level=logging.DEBUG, - env=ENV, - code_location=code_location, - image_uri="image_uri", - ) - model.partition(GPU_INSTANCE) - kwargs = _create_estimator.call_args.kwargs - return kwargs["s3_output_uri"] - actual, expected = _test_default_bucket_and_prefix_combinations( - function_with_user_input=with_user_input, - function_without_user_input=without_user_input, - expected__without_user_input__with_default_bucket_and_default_prefix=( - expected__without_user_input__with_default_bucket_and_default_prefix - ), - expected__without_user_input__with_default_bucket_only=expected__without_user_input__with_default_bucket_only, - expected__with_user_input__with_default_bucket_and_prefix=( - expected__with_user_input__with_default_bucket_and_prefix - ), - expected__with_user_input__with_default_bucket_only=expected__with_user_input__with_default_bucket_only, - ) - assert actual == expected +def test_create_djl_model_all_provided_args(sagemaker_session): + model = DJLModel( + model_id=HF_MODEL_ID, + sagemaker_session=sagemaker_session, + role=ROLE, + task="text-generation", + djl_framework="djl-tensorrtllm", + dtype="fp16", + tensor_parallel_degree=4, + min_workers=1, + max_workers=4, + job_queue_size=12, + parallel_loading=True, + model_loading_timeout=10, + prediction_timeout=3, + huggingface_hub_token="token", + ) + + assert model.engine == "Python" + assert model.image_uri == TRT_IMAGE_URI + assert model.env == { + "HF_MODEL_ID": HF_MODEL_ID, + "OPTION_ENGINE": "Python", + "HF_TASK": "text-generation", + "TENSOR_PARALLEL_DEGREE": "4", + "SERVING_MIN_WORKERS": "1", + "SERVING_MAX_WORKERS": "4", + "SERVING_JOB_QUEUE_SIZE": "12", + "OPTION_PARALLEL_LOADING": "True", + "OPTION_MODEL_LOADING_TIMEOUT": "10", + "OPTION_PREDICT_TIMEOUT": "3", + "HF_TOKEN": "token", + "OPTION_DTYPE": "fp16", + } diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index fd45601801..1698da3e90 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -51,6 +51,8 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.interactive_apps import SupportedInteractiveAppTypes from sagemaker.model import FrameworkModel +from sagemaker.model_card.model_card import ModelCard, ModelOverview +from sagemaker.model_card.schema_constraints import ModelCardStatusEnum from sagemaker.mxnet.estimator import MXNet from sagemaker.predictor import Predictor from sagemaker.pytorch.estimator import PyTorch @@ -72,6 +74,9 @@ DEFAULT_S3_BUCKET_NAME, DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, ) +from sagemaker.model_life_cycle import ModelLifeCycle + +from tests.unit.test_job import INSTANCE_PLACEMENT_CONFIG MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -86,6 +91,7 @@ INSTANCE_COUNT = 1 INSTANCE_TYPE = "c4.4xlarge" KEEP_ALIVE_PERIOD_IN_SECONDS = 1800 +TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan" ACCELERATOR_TYPE = "ml.eia.medium" ROLE = "DummyRole" IMAGE_URI = "fakeimage" @@ -264,6 +270,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): return MODEL_CONTAINER_DEF @@ -857,6 +864,39 @@ def test_framework_with_keep_alive_period(sagemaker_session): assert args["resource_config"]["KeepAlivePeriodInSeconds"] == KEEP_ALIVE_PERIOD_IN_SECONDS +def 
test_framework_with_training_plan(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + training_plan=TRAINING_PLAN, + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args["resource_config"]["TrainingPlanArn"] == TRAINING_PLAN + + +def test_framework_with_instance_placement(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_type="ml.c4.xlarge", + instance_count=2, + training_plan=TRAINING_PLAN, + instance_placement_config=INSTANCE_PLACEMENT_CONFIG, + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args["resource_config"]["InstancePlacementConfig"] == INSTANCE_PLACEMENT_CONFIG + + def test_framework_with_both_training_repository_config(sagemaker_session): f = DummyFramework( entry_point=SCRIPT_PATH, @@ -2772,7 +2812,7 @@ def test_git_support_bad_repo_url_format(sagemaker_session): ) with pytest.raises(ValueError) as error: fw.fit() - assert "Invalid Git url provided." in str(error) + assert "Unsupported URL scheme" in str(error) @patch( @@ -4336,6 +4376,17 @@ def test_register_default_image(sagemaker_session): framework_version = "2.9" nearest_model_name = "resnet50" data_input_config = '{"input_1":[1,224,224,3]}' + model_overview = ModelOverview(model_creator="TestCreator") + model_card = ModelCard( + name="TestCard", + status=ModelCardStatusEnum.DRAFT, + model_overview=model_overview, + ) + update_model_life_cycle = ModelLifeCycle( + stage="Development", + stage_status="In-Progress", + stage_description="Sending for Staging Verification", + ) estimator.register( content_types=content_types, @@ -4349,8 +4400,19 @@ def test_register_default_image(sagemaker_session): framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_config, + model_card=model_card, + model_life_cycle=update_model_life_cycle, ) sagemaker_session.create_model.assert_not_called() + exp_model_card = { + "ModelCardStatus": "Draft", + "ModelCardContent": '{"model_overview": {"model_creator": "TestCreator", "model_artifact": []}}', + } + exp_model_life_cycle = { + "Stage": "Development", + "StageStatus": "In-Progress", + "StageDescription": "Sending for Staging Verification", + } expected_create_model_package_request = { "containers": [{"Image": estimator.image_uri, "ModelDataUrl": estimator.model_data}], @@ -4362,6 +4424,8 @@ def test_register_default_image(sagemaker_session): "marketplace_cert": False, "sample_payload_url": sample_payload_url, "task": task, + "model_life_cycle": exp_model_life_cycle, + "model_card": exp_model_card, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request @@ -4388,7 +4452,7 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - + model_card = {"ModelCardStatus": ModelCardStatusEnum.DRAFT, "ModelCardContent": "{}"} estimator.register( content_types=content_types, response_types=response_types, @@ -4411,6 +4475,7 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): "marketplace_cert": False, "sample_payload_url": 
sample_payload_url, "task": task, + "model_card": model_card, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request @@ -4440,6 +4505,7 @@ def test_register_inference_image(sagemaker_session): framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" + model_card = {"ModelCardStatus": ModelCardStatusEnum.DRAFT, "ModelCardContent": "{}"} estimator.register( content_types=content_types, @@ -4466,6 +4532,7 @@ "marketplace_cert": False, "sample_payload_url": sample_payload_url, "task": task, + "model_card": model_card, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request @@ -5243,6 +5310,7 @@ def test_all_framework_estimators_add_jumpstart_uri_tags( entry_point="inference.py", role=ROLE, tags=[{"Key": "blah", "Value": "yoyoma"}], + model_reference_arn=None, ) assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ @@ -5904,3 +5972,38 @@ def test_estimator_get_app_url_fail(sagemaker_session): f.get_app_url("fake-app") assert "does not support URL retrieval." in str(error) + + +@patch("sagemaker.mlflow.forward_sagemaker_metrics.log_sagemaker_job_to_mlflow") +def test_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + enable_network_isolation=True, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + ], + ) + + # patch.dict sets the environment variable and restores os.environ to its prior state after the test. + with patch.dict(os.environ, {"MLFLOW_TRACKING_URI": "test_uri"}): + f.fit("s3://mydata") + + mock_log_to_mlflow.assert_called_once() + + +@patch("sagemaker.mlflow.forward_sagemaker_metrics.log_sagemaker_job_to_mlflow") +def test_no_forward_sagemaker_metrics(mock_log_to_mlflow, sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + enable_network_isolation=False, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + ], + ) + with patch.dict(os.environ, {"MLFLOW_TRACKING_URI": "test_uri"}): + f.fit("s3://mydata") + mock_log_to_mlflow.assert_not_called() diff --git a/tests/unit/test_exception_on_bad_status.py b/tests/unit/test_exception_on_bad_status.py index 2ef017efd3..dc53c97799 100644 --- a/tests/unit/test_exception_on_bad_status.py +++ b/tests/unit/test_exception_on_bad_status.py @@ -52,7 +52,7 @@ def test_raise_when_failed_created_package(): False ), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not" except Exception as e: - assert type(e) == sagemaker.exceptions.UnexpectedStatusException + assert isinstance(e, sagemaker.exceptions.UnexpectedStatusException) assert e.actual_status == "EnRoute" assert "Completed" in e.allowed_statuses @@ -73,7 +73,7 @@ def test_does_raise_when_incorrect_job_status(): False ), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not" except Exception as e: - assert type(e) == sagemaker.exceptions.UnexpectedStatusException + assert isinstance(e, sagemaker.exceptions.UnexpectedStatusException) assert e.actual_status == "Failed" assert "Completed" in e.allowed_statuses assert "Stopped" in e.allowed_statuses @@ -92,7 +92,7 @@ def test_does_raise_capacity_error_when_incorrect_job_status(): ) assert False, "sagemaker.exceptions.CapacityError should have been raised but was not" except Exception as e: - assert 
type(e) == sagemaker.exceptions.CapacityError + assert isinstance(e, sagemaker.exceptions.CapacityError) assert e.actual_status == "Failed" assert "Completed" in e.allowed_statuses assert "Stopped" in e.allowed_statuses @@ -114,6 +114,6 @@ def test_raise_when_failed_to_deploy_endpoint(): False ), "sagemaker.exceptions.UnexpectedStatusException should have been raised but was not" except Exception as e: - assert type(e) == sagemaker.exceptions.UnexpectedStatusException + assert isinstance(e, sagemaker.exceptions.UnexpectedStatusException) assert e.actual_status == "Failed" assert "InService" in e.allowed_statuses diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index e955d68227..97d4e6ec2a 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -854,17 +854,14 @@ def test_validate_smdataparallel_args_raises(): # Cases {PT|TF2} # 1. None instance type - # 2. incorrect instance type - # 3. incorrect python version - # 4. incorrect framework version + # 2. incorrect python version + # 3. incorrect framework version bad_args = [ (None, "tensorflow", "2.3.1", "py3", smdataparallel_enabled), - ("ml.p3.2xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled), ("ml.p3dn.24xlarge", "tensorflow", "2.3.1", "py2", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "1.3.1", "py3", smdataparallel_enabled), (None, "pytorch", "1.6.0", "py3", smdataparallel_enabled), - ("ml.p3.2xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled), ("ml.p3dn.24xlarge", "pytorch", "1.6.0", "py2", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.5.0", "py3", smdataparallel_enabled), ] @@ -966,74 +963,6 @@ def test_validate_smdataparallel_args_not_raises(): ) -def test_validate_pytorchddp_not_raises(): - # Case 1: Framework is not PyTorch - fw_utils.validate_pytorch_distribution( - distribution=None, - framework_name="tensorflow", - framework_version="2.9.1", - py_version="py3", - image_uri="custom-container", - ) - # Case 2: Framework is PyTorch, but distribution is not PyTorchDDP - pytorchddp_disabled = {"pytorchddp": {"enabled": False}} - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_disabled, - framework_name="pytorch", - framework_version="1.10", - py_version="py3", - image_uri="custom-container", - ) - # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions - pytorchddp_enabled = {"pytorchddp": {"enabled": True}} - pytorchddp_supported_fw_versions = [ - "1.10", - "1.10.0", - "1.10.2", - "1.11", - "1.11.0", - "1.12", - "1.12.0", - "1.12.1", - "1.13.1", - "2.0.0", - "2.0.1", - "2.1.0", - "2.2.0", - ] - for framework_version in pytorchddp_supported_fw_versions: - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_enabled, - framework_name="pytorch", - framework_version=framework_version, - py_version="py3", - image_uri="custom-container", - ) - - -def test_validate_pytorchddp_raises(): - pytorchddp_enabled = {"pytorchddp": {"enabled": True}} - # Case 1: Unsupported framework version - with pytest.raises(ValueError): - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_enabled, - framework_name="pytorch", - framework_version="1.8", - py_version="py3", - image_uri=None, - ) - - # Case 2: Unsupported Py version - with pytest.raises(ValueError): - fw_utils.validate_pytorch_distribution( - distribution=pytorchddp_enabled, - framework_name="pytorch", - framework_version="1.10", - py_version="py2", - image_uri=None, - ) - - def 
test_validate_torch_distributed_not_raises(): # Case 1: Framework is PyTorch, but torch_distributed is not enabled torch_distributed_disabled = {"torch_distributed": {"enabled": False}} diff --git a/tests/unit/test_git_utils.py b/tests/unit/test_git_utils.py index 03bbc1ebcd..2d10ac7619 100644 --- a/tests/unit/test_git_utils.py +++ b/tests/unit/test_git_utils.py @@ -12,11 +12,12 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import pytest import os -from pathlib import Path import subprocess -from mock import patch, ANY +from pathlib import Path + +import pytest +from mock import ANY, patch from sagemaker import git_utils @@ -494,3 +495,212 @@ def test_git_clone_repo_codecommit_https_creds_not_stored_locally(tempdir, mkdte with pytest.raises(subprocess.CalledProcessError) as error: git_utils.git_clone_repo(git_config, entry_point) assert "returned non-zero exit status" in str(error.value) + + +class TestGitUrlSanitization: + """Test cases for Git URL sanitization to prevent injection attacks.""" + + def test_sanitize_git_url_valid_https_urls(self): + """Test that valid HTTPS URLs pass sanitization.""" + valid_urls = [ + "https://github.com/user/repo.git", + "https://gitlab.com/user/repo.git", + "https://token@github.com/user/repo.git", + "https://user:pass@github.com/user/repo.git", + "http://internal-git.company.com/repo.git", + ] + + for url in valid_urls: + # Should not raise any exception + result = git_utils._sanitize_git_url(url) + assert result == url + + def test_sanitize_git_url_valid_ssh_urls(self): + """Test that valid SSH URLs pass sanitization.""" + valid_urls = [ + "git@github.com:user/repo.git", + "git@gitlab.com:user/repo.git", + "ssh://git@github.com/user/repo.git", + "ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/", # 0 @ symbols - valid for ssh:// + "git@internal-git.company.com:repo.git", + ] + + for url in valid_urls: + # Should not raise any exception + result = git_utils._sanitize_git_url(url) + assert result == url + + def test_sanitize_git_url_blocks_multiple_at_https(self): + """Test that HTTPS URLs with multiple @ symbols are blocked.""" + malicious_urls = [ + "https://user@attacker.com@github.com/repo.git", + "https://token@evil.com@gitlab.com/user/repo.git", + "https://a@b@c@github.com/repo.git", + "https://user@malicious-host@github.com/legit/repo.git", + ] + + for url in malicious_urls: + with pytest.raises(ValueError) as error: + git_utils._sanitize_git_url(url) + assert "multiple @ symbols detected" in str(error.value) + + def test_sanitize_git_url_blocks_multiple_at_ssh(self): + """Test that SSH URLs with multiple @ symbols are blocked.""" + malicious_urls = [ + "git@attacker.com@github.com:repo.git", + "git@evil@gitlab.com:user/repo.git", + "ssh://git@malicious@github.com/repo.git", + "git@a@b@c:repo.git", + ] + + for url in malicious_urls: + with pytest.raises(ValueError) as error: + git_utils._sanitize_git_url(url) + # git@ URLs should give "exactly one @ symbol" error + # ssh:// URLs should give "multiple @ symbols detected" error + assert any( + phrase in str(error.value) + for phrase in ["multiple @ symbols detected", "exactly one @ symbol"] + ) + + def test_sanitize_git_url_blocks_invalid_schemes_and_git_at_format(self): + """Test that invalid schemes and git@ format violations are blocked.""" + # Test unsupported schemes + unsupported_scheme_urls = [ + "git-github.com:user/repo.git", # Doesn't start with git@, ssh://, http://, https:// + ] + + for url in 
unsupported_scheme_urls: + with pytest.raises(ValueError) as error: + git_utils._sanitize_git_url(url) + assert "Unsupported URL scheme" in str(error.value) + + # Test git@ URLs with wrong @ count + invalid_git_at_urls = [ + "git@github.com@evil.com:repo.git", # 2 @ symbols + ] + + for url in invalid_git_at_urls: + with pytest.raises(ValueError) as error: + git_utils._sanitize_git_url(url) + assert "exactly one @ symbol" in str(error.value) + + def test_sanitize_git_url_blocks_url_encoding_obfuscation(self): + """Test that URL-encoded obfuscation attempts are blocked.""" + obfuscated_urls = [ + "https://github.com%25evil.com/repo.git", + "https://user@github.com%40attacker.com/repo.git", + "https://github.com%2Fevil.com/repo.git", + "https://github.com%3Aevil.com/repo.git", + ] + + for url in obfuscated_urls: + with pytest.raises(ValueError) as error: + git_utils._sanitize_git_url(url) + # The error could be either suspicious encoding or invalid characters + assert any( + phrase in str(error.value) + for phrase in ["Suspicious URL encoding detected", "Invalid characters in hostname"] + ) + + def test_sanitize_git_url_blocks_invalid_hostname_chars(self): + """Test that hostnames with invalid characters are blocked.""" + invalid_urls = [ + "https://github", + ] + + for url in invalid_urls: + with pytest.raises(ValueError) as error: + git_utils._sanitize_git_url(url) + assert "Invalid characters in hostname" in str(error.value) + + def test_git_clone_repo_blocks_malicious_https_url(self): + """Test that git_clone_repo blocks malicious HTTPS URLs.""" + malicious_git_config = { + "repo": "https://user@attacker.com@github.com/legit/repo.git", + "branch": "main", + } + entry_point = "train.py" + + with pytest.raises(ValueError) as error: + git_utils.git_clone_repo(malicious_git_config, entry_point) + assert "multiple @ symbols detected" in str(error.value) + + def test_git_clone_repo_blocks_malicious_ssh_url(self): + """Test that git_clone_repo blocks malicious SSH URLs.""" + malicious_git_config = { + "repo": "git@OBVIOUS@github.com:sage-maker/temp-sev2.git", + "branch": "main", + } + entry_point = "train.py" + + with pytest.raises(ValueError) as error: + git_utils.git_clone_repo(malicious_git_config, entry_point) + assert "exactly one @ symbol" in str(error.value) + + def test_git_clone_repo_blocks_url_encoded_attack(self): + """Test that git_clone_repo blocks URL-encoded attacks.""" + malicious_git_config = { + "repo": "https://github.com%40attacker.com/repo.git", + "branch": "main", + } + entry_point = "train.py" + + with pytest.raises(ValueError) as error: + git_utils.git_clone_repo(malicious_git_config, entry_point) + assert "Suspicious URL encoding detected" in str(error.value) + + def test_sanitize_git_url_comprehensive_attack_scenarios(self): + """Test that known end-to-end attack URL shapes are all blocked.""" + attack_scenarios = [ + # Original PoC attack + "https://USER@YOUR_NGROK_OR_LOCALHOST/malicious.git@github.com%25legit%25repo.git", + # Variations of the attack + "https://user@malicious-host@github.com/legit/repo.git", + "git@attacker.com@github.com:user/repo.git", + "ssh://git@evil.com@github.com/repo.git", + # URL encoding variations + "https://github.com%40evil.com/repo.git", + "https://user@github.com%2Fevil.com/repo.git", + ] + + entry_point = "train.py" + + for malicious_url in attack_scenarios: + git_config = {"repo": malicious_url} + with pytest.raises(ValueError) as error: + git_utils.git_clone_repo(git_config, entry_point) + # Should be blocked by sanitization + assert any( + phrase in str(error.value) + for phrase in [ + "multiple @ symbols 
detected", + "exactly one @ symbol", + "Suspicious URL encoding detected", + "Invalid characters in hostname", + ] + ) diff --git a/tests/unit/test_hyperparameter.py b/tests/unit/test_hyperparameter.py index ba7a363c40..edb2de97ee 100644 --- a/tests/unit/test_hyperparameter.py +++ b/tests/unit/test_hyperparameter.py @@ -62,7 +62,7 @@ def test_validated(): def test_data_type(): x = Test() x.validated = 66 - assert type(x.validated) == Test.__dict__["validated"].data_type + assert isinstance(x.validated, Test.__dict__["validated"].data_type) def test_from_string(): diff --git a/tests/unit/test_inputs.py b/tests/unit/test_inputs.py index 7d9c2b2c2f..133c31eb75 100644 --- a/tests/unit/test_inputs.py +++ b/tests/unit/test_inputs.py @@ -41,6 +41,8 @@ def test_training_input_all_arguments(): record_wrapping = "RecordIO" s3_data_type = "Manifestfile" input_mode = "Pipe" + hub_access_config = {"HubContentArn": "some-hub-content-arn"} + model_access_config = {"AcceptEula": True} result = TrainingInput( s3_data=prefix, distribution=distribution, @@ -49,6 +51,8 @@ def test_training_input_all_arguments(): content_type=content_type, record_wrapping=record_wrapping, s3_data_type=s3_data_type, + hub_access_config=hub_access_config, + model_access_config=model_access_config, ) expected = { "DataSource": { @@ -56,6 +60,8 @@ def test_training_input_all_arguments(): "S3DataDistributionType": distribution, "S3DataType": s3_data_type, "S3Uri": prefix, + "ModelAccessConfig": model_access_config, + "HubAccessConfig": hub_access_config, } }, "CompressionType": compression, @@ -76,6 +82,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): s3_data_type = "Manifestfile" instance_groups = ["data-server"] input_mode = "Pipe" + hub_access_config = {"HubContentArn": "some-hub-content-arn"} + model_access_config = {"AcceptEula": True} result = TrainingInput( s3_data=prefix, distribution=distribution, @@ -85,6 +93,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): record_wrapping=record_wrapping, s3_data_type=s3_data_type, instance_groups=instance_groups, + hub_access_config=hub_access_config, + model_access_config=model_access_config, ) expected = { @@ -94,6 +104,8 @@ def test_training_input_all_arguments_heterogeneous_cluster(): "S3DataType": s3_data_type, "S3Uri": prefix, "InstanceGroupNames": instance_groups, + "ModelAccessConfig": model_access_config, + "HubAccessConfig": hub_access_config, } }, "CompressionType": compression, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 603b494e5a..cdd4a2630e 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -31,6 +31,11 @@ INSTANCE_COUNT = 1 INSTANCE_TYPE = "c4.4xlarge" KEEP_ALIVE_PERIOD = 1800 +TRAINING_PLAN = "arn:aws:sagemaker:us-west-2:336:training-plan/test_training_plan" +INSTANCE_PLACEMENT_CONFIG = { + "EnableMultipleJobs": True, + "PlacementSpecifications": [{"UltraServerId": "us-1", "InstanceCount": "2"}], +} INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1) VOLUME_SIZE = 1 MAX_RUNTIME = 1 @@ -205,6 +210,32 @@ def test_load_config_with_model_channel_no_inputs(estimator): assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME +def test_load_config_with_access_configs(estimator): + estimator.model_uri = MODEL_URI + estimator.model_channel_name = MODEL_CHANNEL_NAME + estimator.model_access_config = {"AcceptEula": True} + estimator.hub_access_config = {"HubContentArn": "dummy_arn"} + + config = _Job._load_config(inputs=None, estimator=estimator) + assert 
config["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == MODEL_URI + assert config["input_config"][0]["ChannelName"] == MODEL_CHANNEL_NAME + assert config["role"] == ROLE + assert config["output_config"]["S3OutputPath"] == S3_OUTPUT_PATH + assert "KmsKeyId" not in config["output_config"] + assert config["resource_config"]["InstanceCount"] == INSTANCE_COUNT + assert config["resource_config"]["InstanceType"] == INSTANCE_TYPE + assert config["resource_config"]["VolumeSizeInGB"] == VOLUME_SIZE + assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME + assert ( + config["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == estimator.model_access_config + ) + assert ( + config["input_config"][0]["DataSource"]["S3DataSource"]["HubAccessConfig"] + == estimator.hub_access_config + ) + + def test_load_config_with_code_channel(framework): inputs = TrainingInput(BUCKET_NAME) @@ -346,20 +377,43 @@ def test_format_record_set_list_input(): @pytest.mark.parametrize( - "channel_uri, channel_name, content_type, input_mode", + "channel_uri, channel_name, content_type, input_mode, model_access_config, hub_access_config", [ - [MODEL_URI, MODEL_CHANNEL_NAME, "application/x-sagemaker-model", "File"], - [CODE_URI, CODE_CHANNEL_NAME, None, None], + [ + MODEL_URI, + MODEL_CHANNEL_NAME, + "application/x-sagemaker-model", + "File", + {"AcceptEula": True}, + None, + ], + [CODE_URI, CODE_CHANNEL_NAME, None, None, None, {"HubContentArn": "dummy_arn"}], ], ) -def test_prepare_channel(channel_uri, channel_name, content_type, input_mode): +def test_prepare_channel( + channel_uri, channel_name, content_type, input_mode, model_access_config, hub_access_config +): channel = _Job._prepare_channel( - [], channel_uri, channel_name, content_type=content_type, input_mode=input_mode + [], + channel_uri, + channel_name, + content_type=content_type, + input_mode=input_mode, + model_access_config=model_access_config, + hub_access_config=hub_access_config, ) assert channel["DataSource"]["S3DataSource"]["S3Uri"] == channel_uri assert channel["DataSource"]["S3DataSource"]["S3DataDistributionType"] == "FullyReplicated" assert channel["DataSource"]["S3DataSource"]["S3DataType"] == "S3Prefix" + if hub_access_config: + assert channel["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + else: + assert "HubAccessConfig" not in channel["DataSource"]["S3DataSource"] + if model_access_config: + assert channel["DataSource"]["S3DataSource"]["ModelAccessConfig"] == model_access_config + else: + assert "ModelAccessConfig" not in channel["DataSource"]["S3DataSource"] assert channel["ChannelName"] == channel_name assert "CompressionType" not in channel assert "RecordWrapperType" not in channel @@ -545,6 +599,23 @@ def test_format_string_uri_input_string(): assert s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == inputs +def test_format_string_uri_input_string_with_access_configs(): + inputs = BUCKET_NAME + model_access_config = {"AcceptEula": True} + hub_access_config = {"HubContentArn": "dummy_arn"} + + s3_uri_input = _Job._format_string_uri_input( + inputs, model_access_config=model_access_config, hub_access_config=hub_access_config + ) + + assert s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] == inputs + assert s3_uri_input.config["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == model_access_config + ) + + def test_format_string_uri_file_system_input(): 
file_system_id = "fs-fd85e556" file_system_type = "EFS" @@ -584,6 +655,26 @@ def test_format_string_uri_input(): ) +def test_format_string_uri_input_with_access_configs(): + inputs = TrainingInput(BUCKET_NAME) + model_access_config = {"AcceptEula": True} + hub_access_config = {"HubContentArn": "dummy_arn"} + + s3_uri_input = _Job._format_string_uri_input( + inputs, model_access_config=model_access_config, hub_access_config=hub_access_config + ) + + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["S3Uri"] + == inputs.config["DataSource"]["S3DataSource"]["S3Uri"] + ) + assert s3_uri_input.config["DataSource"]["S3DataSource"]["HubAccessConfig"] == hub_access_config + assert ( + s3_uri_input.config["DataSource"]["S3DataSource"]["ModelAccessConfig"] + == model_access_config + ) + + def test_format_string_uri_input_exception(): inputs = 1 @@ -633,19 +724,73 @@ def test_prepare_output_config_kms_key_none(): def test_prepare_resource_config(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None, None + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + None, + None, + None, + ) + + assert resource_config == { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": VOLUME_SIZE, + } + + +def test_prepare_resource_config_with_training_plan(): + resource_config = _Job._prepare_resource_config( + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + None, + TRAINING_PLAN, + ) + + assert resource_config == { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + "VolumeSizeInGB": VOLUME_SIZE, + "VolumeKmsKeyId": VOLUME_KMS_KEY, + "TrainingPlanArn": TRAINING_PLAN, + } + + +def test_prepare_resource_config_with_placement_config(): + resource_config = _Job._prepare_resource_config( + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + None, + TRAINING_PLAN, + INSTANCE_PLACEMENT_CONFIG, ) assert resource_config == { "InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE, "VolumeSizeInGB": VOLUME_SIZE, + "VolumeKmsKeyId": VOLUME_KMS_KEY, + "TrainingPlanArn": TRAINING_PLAN, + "InstancePlacementConfig": INSTANCE_PLACEMENT_CONFIG, } def test_prepare_resource_config_with_keep_alive_period(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, KEEP_ALIVE_PERIOD + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + KEEP_ALIVE_PERIOD, + None, ) assert resource_config == { @@ -659,7 +804,13 @@ def test_prepare_resource_config_with_keep_alive_period(): def test_prepare_resource_config_with_volume_kms(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY, None + INSTANCE_COUNT, + INSTANCE_TYPE, + None, + VOLUME_SIZE, + VOLUME_KMS_KEY, + None, + None, ) assert resource_config == { @@ -678,6 +829,7 @@ def test_prepare_resource_config_with_heterogeneous_cluster(): VOLUME_SIZE, None, None, + None, ) assert resource_config == { @@ -698,6 +850,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou VOLUME_SIZE, None, None, + None, ) assert "instance_count and instance_type cannot be set when instance_groups is set" in str( error @@ -713,6 +866,7 @@ def test_prepare_resource_config_with_instance_groups_instance_type_instance_cou VOLUME_SIZE, None, None, + None, ) assert "instance_count and instance_type must be set if instance_groups is not set" in str( error diff --git 
a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 4a584dfae4..2c47356921 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -360,7 +360,7 @@ def test_mxnet( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names actual_train_args = sagemaker_session.method_calls[0][2] job_name = actual_train_args["job_name"] diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index b546d4e9e8..07d419779f 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -420,3 +420,27 @@ def test_network_isolation(tfo, time, sagemaker_session): vpc_config=None, enable_network_isolation=True, ) + + +def test_pipeline_model_register(sagemaker_session): + sagemaker_session.create_model_package_from_containers = Mock( + name="create_model_package_from_containers", + return_value={ + "ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" + }, + ) + framework_model = DummyFrameworkModel(sagemaker_session) + sparkml_model = SparkMLModel( + model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session + ) + model = PipelineModel( + models=[framework_model, sparkml_model], + role=ROLE, + sagemaker_session=sagemaker_session, + enable_network_isolation=True, + ) + model_package = model.register() + assert ( + model_package.model_package_arn + == "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" + ) diff --git a/tests/unit/test_predictor_async.py b/tests/unit/test_predictor_async.py index fa2d6da6c7..c9f12ff023 100644 --- a/tests/unit/test_predictor_async.py +++ b/tests/unit/test_predictor_async.py @@ -233,7 +233,7 @@ def test_async_predict_call_verify_exceptions(): with pytest.raises( PollingTimeoutError, match=f"No result at {ASYNC_OUTPUT_LOCATION} after polling for " - f"{DEFAULT_WAITER_CONFIG.delay*DEFAULT_WAITER_CONFIG.max_attempts}" + f"{DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts}" f" seconds. Inference could still be running", ): predictor_async.predict(input_path=input_location, waiter_config=DEFAULT_WAITER_CONFIG) @@ -253,7 +253,7 @@ def test_async_predict_call_verify_exceptions_with_null_failure_path(): with pytest.raises( PollingTimeoutError, match=f"No result at {ASYNC_OUTPUT_LOCATION} after polling for " - f"{DEFAULT_WAITER_CONFIG.delay*DEFAULT_WAITER_CONFIG.max_attempts}" + f"{DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts}" f" seconds. 
Inference could still be running", ): predictor_async.predict(input_path=input_location, waiter_config=DEFAULT_WAITER_CONFIG) diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 93e3d91f87..06d2cde02e 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import copy +from textwrap import dedent import pytest from mock import Mock, patch, MagicMock @@ -1102,6 +1103,137 @@ def test_pyspark_processor_configuration_path_pipeline_config( ) +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_get_codeartifact_command(pipeline_session): + codeartifact_repo_arn = ( + "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + ) + + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + codeartifact_command = processor._get_codeartifact_command( + codeartifact_repo_arn=codeartifact_repo_arn + ) + + assert ( + codeartifact_command + == "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" # noqa: E501 # pylint: disable=line-too-long + ) + + +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_get_codeartifact_command_bad_repo_arn(pipeline_session): + codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain" + + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + with pytest.raises(ValueError): + processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn) + + +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_generate_framework_script(pipeline_session): + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + framework_script = processor._generate_framework_script(user_script="process.py") + + assert framework_script == dedent( + """\ + #!/bin/bash + + cd /opt/ml/processing/input/code/ + tar -xzf sourcedir.tar.gz + + # Exit on any error. SageMaker uses error code to mark failed job. + set -e + + if [[ -f 'requirements.txt' ]]; then + # Optionally log into CodeArtifact + if ! hash aws 2>/dev/null; then + echo "AWS CLI is not installed. Skipping CodeArtifact login." + else + echo 'CodeArtifact repository not specified. Skipping login.' 
+ fi + + # Some py3 containers has typing, which may breaks pip install + pip uninstall --yes typing + + pip install -r requirements.txt + fi + + python process.py "$@" + """ + ) + + +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_generate_framework_script_with_codeartifact(pipeline_session): + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + framework_script = processor._generate_framework_script( + user_script="process.py", + codeartifact_repo_arn=( + "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + ), + ) + + assert framework_script == dedent( + """\ + #!/bin/bash + + cd /opt/ml/processing/input/code/ + tar -xzf sourcedir.tar.gz + + # Exit on any error. SageMaker uses error code to mark failed job. + set -e + + if [[ -f 'requirements.txt' ]]; then + # Optionally log into CodeArtifact + if ! hash aws 2>/dev/null; then + echo "AWS CLI is not installed. Skipping CodeArtifact login." + else + aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2 + fi + + # Some py3 containers has typing, which may breaks pip install + pip uninstall --yes typing + + pip install -r requirements.txt + fi + + python process.py "$@" + """ # noqa: E501 # pylint: disable=line-too-long + ) + + def _get_script_processor(sagemaker_session): return ScriptProcessor( role=ROLE, diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 5ada026ef8..8352f3090b 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -18,10 +18,15 @@ import pytest from mock import ANY, MagicMock, Mock, patch from packaging.version import Version +import tempfile from sagemaker import image_uris from sagemaker.pytorch import defaults from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel +from sagemaker.pytorch.estimator import ( + _get_training_recipe_image_uri, + _get_training_recipe_gpu_script, +) from sagemaker.instance_group import InstanceGroup from sagemaker.session_settings import SessionSettings @@ -35,6 +40,8 @@ BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 INSTANCE_TYPE = "ml.c4.4xlarge" +INSTANCE_TYPE_GPU = "ml.p4d.24xlarge" +INSTANCE_TYPE_TRAINIUM = "ml.trn1.32xlarge" ACCELERATOR_TYPE = "ml.eia.medium" IMAGE_URI = "sagemaker-pytorch" JOB_NAME = "{}-{}".format(IMAGE_URI, TIMESTAMP) @@ -59,6 +66,18 @@ } DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}} +NEURON_RECIPE = ( + "https://raw.githubusercontent.com/aws-neuron/" + "neuronx-distributed-training/refs/heads/main/examples/" + "conf/hf_llama3_8B_config.yaml" +) +RECIPE_GPU_IMAGE = ( + "658645717510.dkr.ecr.us-west-2.amazonaws.com/smdistributed-modelparallel:2.4.1-gpu-py311" +) +RECIPE_NEURON_IMAGE = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04" +) @pytest.fixture(name="sagemaker_session") @@ -337,7 +356,7 @@ def test_pytorch( sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job(pytorch_inference_version, pytorch_inference_py_version) 
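+ # The baseline expected arguments come from the _create_train_job helper; the next line swaps in this test's input URI.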
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs @@ -801,14 +820,15 @@ def test_pytorch_ddp_distribution_configuration( distribution=pytorch.distribution ) expected_torch_ddp = { - "sagemaker_pytorch_ddp_enabled": True, + "sagemaker_distributed_dataparallel_enabled": True, + "sagemaker_distributed_dataparallel_custom_mpi_options": "", "sagemaker_instance_type": test_instance_type, } assert actual_pytorch_ddp == expected_torch_ddp def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session): - unsupported_framework_version = "1.9.1" + unsupported_framework_version = "1.5.0" unsupported_py_version = "py2" with pytest.raises(ValueError) as error: _pytorch_estimator( @@ -825,3 +845,317 @@ def test_predictor_with_component_name(sagemaker_session, component_name): predictor = PyTorchPredictor("endpoint", sagemaker_session, component_name=component_name) assert predictor._get_component_name() == component_name + + +def test_training_recipe_for_cpu(sagemaker_session): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output/tensorboard", + "checkpoint_dir": "/opt/ml/checkpoints", + }, + "model": { + "data": { + "train_dir": "/opt/ml/input/data/train", + "val_dir": "/opt/ml/input/data/val", + }, + }, + } + + with pytest.raises(ValueError): + PyTorch( + output_path="s3://output_path", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + base_job_name="job", + container_log_level=container_log_level, + training_recipe="training/llama/hf_llama3_8b_seq8k_gpu_p5x16_pretrain", + recipe_overrides=recipe_overrides, + ) + + +@pytest.mark.parametrize( + "recipe, model", + [ + ("hf_llama3_8b_seq8k_gpu_p5x16_pretrain", "llama"), + ("hf_mistral_7b_seq8k_gpu_p5x16_pretrain", "mistral"), + ("hf_mixtral_8x7b_seq8k_gpu_p5x16_pretrain", "mixtral"), + ], +) +def test_training_recipe_for_gpu(sagemaker_session, recipe, model): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output", + "checkpoint_dir": "/opt/ml/checkpoints", + }, + "model": { + "data": { + "train_dir": "/opt/ml/input/data/train", + "val_dir": "/opt/ml/input/data/val", + }, + }, + } + pytorch = PyTorch( + output_path="s3://output_path", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + base_job_name="job", + container_log_level=container_log_level, + training_recipe=f"training/{model}/{recipe}", + recipe_overrides=recipe_overrides, + ) + + assert pytorch.source_dir == "." 
+ assert pytorch.entry_point == f"{model}_pretrain.py" + expected_distribution = { + "torch_distributed": { + "enabled": True, + }, + "smdistributed": { + "modelparallel": { + "enabled": True, + "parameters": { + "placement_strategy": "cluster", + }, + }, + }, + } + assert pytorch.distribution.items() == expected_distribution.items() + + +def test_training_recipe_with_override(sagemaker_session): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output", + "checkpoint_dir": "/opt/ml/checkpoints", + }, + "model": { + "data": { + "train_dir": "/opt/ml/input/data/train", + "val_dir": "/opt/ml/input/data/val", + }, + "model_type": "mistral", + }, + } + pytorch = PyTorch( + output_path="s3://output_path", + role=ROLE, + image_uri=IMAGE_URI, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + base_job_name="job", + container_log_level=container_log_level, + training_recipe="training/llama/hf_llama3_8b_seq8k_gpu_p5x16_pretrain", + recipe_overrides=recipe_overrides, + ) + + assert pytorch.source_dir == "." + assert pytorch.entry_point == "mistral_pretrain.py" + assert pytorch.image_uri == IMAGE_URI + + +def test_training_recipe_gpu_custom_source_dir(sagemaker_session): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output", + "checkpoint_dir": "/opt/ml/checkpoints", + }, + "model": { + "data": { + "train_dir": "/opt/ml/input/data/train", + "val_dir": "/opt/ml/input/data/val", + }, + "model_type": "mistral", + }, + } + source_dir = tempfile.TemporaryDirectory(prefix="source_") + pytorch = PyTorch( + output_path="s3://output_path", + role=ROLE, + image_uri=IMAGE_URI, + source_dir=source_dir.name, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + base_job_name="job", + container_log_level=container_log_level, + training_recipe="training/llama/hf_llama3_8b_seq8k_gpu_p5x16_pretrain", + recipe_overrides=recipe_overrides, + ) + + assert pytorch.source_dir == source_dir.name + assert pytorch.entry_point == "mistral_pretrain.py" + assert pytorch.image_uri == IMAGE_URI + + +def test_training_recipe_for_trainium(sagemaker_session): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output", + }, + "data": { + "train_dir": "/opt/ml/input/data/train", + }, + "model": { + "model_config": "/opt/ml/input/data/train/config.json", + }, + "compiler_cache_url": "s3://output_path/neuron-cache", + } + pytorch = PyTorch( + output_path="s3://output_path", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_TRAINIUM, + base_job_name="job", + container_log_level=container_log_level, + training_recipe=NEURON_RECIPE, + recipe_overrides=recipe_overrides, + ) + + assert pytorch.source_dir == "." 
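+ # Trainium recipes use the Neuron training orchestrator as the entry point and torch_distributed as the only distribution, as asserted below.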
+ assert pytorch.entry_point == "training_orchestrator.py" + expected_distribution = { + "torch_distributed": { + "enabled": True, + }, + } + assert pytorch.distribution == expected_distribution + + +@pytest.mark.parametrize( + "test_case", + [ + { + "script": "llama_pretrain.py", + "recipe": { + "model": { + "model_type": "llama_v3", + }, + }, + }, + { + "script": "mistral_pretrain.py", + "recipe": { + "model": { + "model_type": "mistral", + }, + }, + }, + { + "script": "deepseek_pretrain.py", + "recipe": { + "model": { + "model_type": "deepseek_llamav3", + }, + }, + }, + { + "script": "deepseek_pretrain.py", + "recipe": { + "model": { + "model_type": "deepseek_qwenv2", + }, + }, + }, + { + "script": "custom_pretrain.py", + "recipe": { + "model": { + "model_type": "gpt_oss", + }, + }, + }, + ], +) +@patch("shutil.copyfile") +def test_get_training_recipe_gpu_script(mock_copyfile, test_case): + script = test_case["script"] + recipe = test_case["recipe"] + mock_copyfile.return_value = None + + assert _get_training_recipe_gpu_script("code_dir", recipe, "source_dir") == script + + +def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session): + container_log_level = '"logging.INFO"' + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "exp_manager": { + "explicit_log_dir": "/opt/ml/output", + }, + "data": { + "train_dir": "/opt/ml/input/data/train", + }, + "model": { + "model_config": "/opt/ml/input/data/train/config.json", + }, + "compiler_cache_url": "s3://output_path/neuron-cache", + } + source_dir = tempfile.TemporaryDirectory(prefix="source_") + pytorch = PyTorch( + output_path="s3://output_path", + role=ROLE, + source_dir=source_dir.name, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_TRAINIUM, + base_job_name="job", + container_log_level=container_log_level, + training_recipe=NEURON_RECIPE, + recipe_overrides=recipe_overrides, + ) + + assert pytorch.source_dir == source_dir.name + assert pytorch.entry_point == "training_orchestrator.py" + expected_distribution = { + "torch_distributed": { + "enabled": True, + }, + } + assert pytorch.distribution == expected_distribution + + +def test_training_recipe_images_uri(): + gpu_image_cfg = {"framework": "pytorch-smp", "version": "2.4.1", "additional_args": {}} + gpu_image_uri = _get_training_recipe_image_uri(gpu_image_cfg, "us-west-2") + assert gpu_image_uri == RECIPE_GPU_IMAGE + neuron_image_cfg = { + "framework": "hyperpod-recipes-neuron", + "version": "2.1.2", + "additional_args": {}, + } + neuron_image_uri = _get_training_recipe_image_uri(neuron_image_cfg, "us-west-2") + assert neuron_image_uri == RECIPE_NEURON_IMAGE diff --git a/tests/unit/test_pytorch_nova.py b/tests/unit/test_pytorch_nova.py new file mode 100644 index 0000000000..f78bdcae7d --- /dev/null +++ b/tests/unit/test_pytorch_nova.py @@ -0,0 +1,753 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
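+# Unit tests for Nova recipe support in the PyTorch estimator: recipe detection, instance count resolution, device validation, distribution selection, recipe upload to S3, and Nova-style hyperparameter encoding.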
+from __future__ import absolute_import +import pytest +import tempfile +from mock import Mock, patch +from omegaconf import OmegaConf + +from sagemaker.estimator import EstimatorBase + +from sagemaker.pytorch import PyTorch +from sagemaker.pytorch.estimator import ( + _is_nova_recipe, + _device_get_distribution, +) +from sagemaker.inputs import TrainingInput +from sagemaker.session_settings import SessionSettings + +# Constants for testing +ROLE = "Dummy" +REGION = "us-west-2" +BUCKET_NAME = "mybucket" +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "ml.c4.4xlarge" +INSTANCE_TYPE_GPU = "ml.p4d.24xlarge" +IMAGE_URI = "sagemaker-pytorch" + + +@pytest.fixture(name="sagemaker_session") +def fixture_sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + s3_resource=None, + s3_client=None, + settings=SessionSettings(), + default_bucket_prefix=None, + ) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + session.expand_role = Mock(name="expand_role", return_value=ROLE) + session.upload_data = Mock(return_value="s3://mybucket/recipes/nova-recipe.yaml") + session.sagemaker_config = {} + return session + + +def test_is_nova_recipe(): + """Test that _is_nova_recipe correctly identifies Nova recipes.""" + # Valid Nova recipe + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foo-bar", + "model_name_or_path": "foo-bar/foo-bar123", + } + } + ) + assert _is_nova_recipe(recipe) is True + + # Not a Nova recipe - missing run section + recipe = OmegaConf.create( + { + "trainer": { + "model_type": "amazon.nova.foo-bar", + "model_name_or_path": "foo-bar/foo-bar123", + } + } + ) + assert _is_nova_recipe(recipe) is False + + # Not a Nova recipe - wrong model_type + recipe = OmegaConf.create( + {"run": {"model_type": "foo-bar3", "model_name_or_path": "foo-bar/foo-bar123"}} + ) + assert _is_nova_recipe(recipe) is False + + # Not a Nova recipe - missing model_name_or_path + recipe = OmegaConf.create({"run": {"model_type": "amazon.nova.foo-bar"}}) + assert _is_nova_recipe(recipe) is False + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_nova_recipe_with_model_name(mock_resolve_save, sagemaker_session): + """Test that _setup_for_nova_recipe correctly sets up hyperparameters for Nova recipes with model name.""" + # Create a mock recipe + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "replicas": 4, + } + } + ) + + # Setup the expected return value + expected_args = { + "hyperparameters": {"base_model": "foobar/foobar-3-8b"}, + "entry_point": None, + "source_dir": None, + "distribution": {}, + "default_image_uri": IMAGE_URI, + } + + # Mock the _setup_for_nova_recipe method + with patch( + "sagemaker.pytorch.estimator.PyTorch._setup_for_nova_recipe", return_value=expected_args + ) as mock_nova_setup: + # Create the PyTorch estimator with mocked _recipe_load + with patch( + "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe) + ): + # Mock _recipe_resolve_and_save to return our recipe + mock_resolve_save.return_value = recipe + + pytorch = PyTorch( + training_recipe="nova_recipe", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + 
py_version="py3", + ) + + # Check that the Nova recipe was correctly identified + assert pytorch.is_nova_recipe is True + + # Verify _setup_for_nova_recipe was called + mock_nova_setup.assert_called_once() + call_args = mock_nova_setup.call_args + assert len(call_args[0]) >= 2 # Check that at least recipe and recipe_name were passed + assert call_args[0][0] == recipe # first arg should be recipe + assert call_args[0][1] == "nova_recipe" # second arg should be recipe_name + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_nova_recipe_with_s3_path(mock_resolve_save, sagemaker_session): + """Test that _setup_for_nova_recipe correctly sets up hyperparameters for Nova recipes with S3 path.""" + # Create a mock recipe with S3 path + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "s3://mybucket/models/foobar3", + "replicas": 4, + } + } + ) + + # Setup the expected return value + expected_args = { + "hyperparameters": {"base_model_location": "s3://mybucket/models/foobar3"}, + "entry_point": None, + "source_dir": None, + "distribution": {}, + "default_image_uri": IMAGE_URI, + } + + # Mock the _setup_for_nova_recipe method + with patch( + "sagemaker.pytorch.estimator.PyTorch._setup_for_nova_recipe", return_value=expected_args + ) as mock_nova_setup: + # Create the PyTorch estimator with mocked _recipe_load + with patch( + "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe) + ): + # Mock _recipe_resolve_and_save to return our recipe + mock_resolve_save.return_value = recipe + + pytorch = PyTorch( + training_recipe="nova_recipe", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + ) + + # Check that the Nova recipe was correctly identified + assert pytorch.is_nova_recipe is True + + # Verify _setup_for_nova_recipe was called + mock_nova_setup.assert_called_once() + + # Verify that hyperparameters were set correctly + assert ( + pytorch._hyperparameters.get("base_model_location") + == "s3://mybucket/models/foobar3" + ) + + +def test_device_handle_instance_count_with_nova_replicas(): + """Test that _device_handle_instance_count correctly gets instance_count from Nova recipe replicas.""" + # Create mock recipe with replicas + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "replicas": 4, + } + } + ) + + # Test with no instance_count in kwargs + kwargs = {} + PyTorch._device_handle_instance_count(kwargs, recipe) + assert kwargs["instance_count"] == 4 + + +def test_device_handle_instance_count_with_nova_no_replicas(): + """Test that _device_handle_instance_count raises an error when no instance_count or replicas are provided.""" + # Create mock recipe without replicas + recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Test with no instance_count in kwargs + kwargs = {} + with pytest.raises(ValueError) as error: + PyTorch._device_handle_instance_count(kwargs, recipe) + + assert "Must set either instance_count argument for estimator or" in str(error) + + +@patch("sagemaker.pytorch.estimator.logger.warning") +def test_device_handle_instance_count_with_nova_both_provided(mock_warning): + """Test that _device_handle_instance_count warns when both instance_count and replicas 
are provided.""" + # Create mock recipe with replicas + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "replicas": 4, + } + } + ) + + # Test with instance_count in kwargs + kwargs = {"instance_count": 2} + PyTorch._device_handle_instance_count(kwargs, recipe) + + # Verify warning was logged + mock_warning.assert_called_with( + "Using instance_count argument to estimator to set number " + "of nodes. Ignoring run -> replicas in recipe." + ) + + # Verify instance_count wasn't changed + assert kwargs["instance_count"] == 2 + + +def test_device_validate_and_get_type_with_nova(): + """Test that _device_validate_and_get_type works correctly with Nova recipes.""" + # Create mock recipe + recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Test with GPU instance type + kwargs = {"instance_type": INSTANCE_TYPE_GPU} + device_type = PyTorch._device_validate_and_get_type(kwargs, recipe) + assert device_type == "gpu" + + # Test with CPU instance type + kwargs = {"instance_type": INSTANCE_TYPE} + device_type = PyTorch._device_validate_and_get_type(kwargs, recipe) + assert device_type == "cpu" + + +def test_device_validate_and_get_type_no_instance_type(): + """Test that _device_validate_and_get_type raises an error when no instance_type is provided.""" + # Create mock recipe + recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Test with no instance_type + kwargs = {} + with pytest.raises(ValueError) as error: + PyTorch._device_validate_and_get_type(kwargs, recipe) + + assert "Must pass instance type to estimator" in str(error) + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("time.time", return_value=1714500000) # May 1, 2024 +def test_upload_recipe_to_s3(mock_time, mock_recipe_load, sagemaker_session): + """Test that _upload_recipe_to_s3 correctly uploads the recipe file to S3.""" + # Create a mock recipe that will be identified as a Nova recipe + mock_recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("nova_recipe", mock_recipe) + + # Setup + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="nova_recipe", + ) + + # Set Nova recipe attributes + pytorch.is_nova_recipe = True + + # Create a temporary file to use as the recipe file + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + # Test uploading the recipe file to S3 + s3_uri = pytorch._upload_recipe_to_s3(sagemaker_session, temp_file.name) + + # Verify the upload_data method was called with the correct parameters + sagemaker_session.upload_data.assert_called_once() + + # Check that the S3 URI is returned correctly + assert s3_uri == sagemaker_session.upload_data.return_value + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("tempfile.NamedTemporaryFile") +@patch("omegaconf.OmegaConf.save") +@patch("sagemaker.pytorch.estimator._try_resolve_recipe") +def test_recipe_resolve_and_save( + mock_try_resolve, mock_save, mock_temp_file, mock_recipe_load, sagemaker_session +): + """Test that _recipe_resolve_and_save 
correctly resolves and saves the recipe.""" + # Create a mock recipe that will be identified as a Nova recipe + mock_recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("nova_recipe", mock_recipe) + + # Setup + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="nova_recipe", + ) + + # Set Nova recipe attributes + pytorch.is_nova_recipe = True + + # Mock the temporary file + mock_temp_file_instance = Mock() + mock_temp_file_instance.name = "/tmp/nova-recipe_12345.yaml" + mock_temp_file.return_value = mock_temp_file_instance + + # Create mock recipe + recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Mock the recipe resolution + mock_try_resolve.side_effect = [recipe, None, None] + + # Call the _recipe_resolve_and_save method + result = pytorch._recipe_resolve_and_save(recipe, "nova-recipe", ".") + + # Verify the recipe was resolved and saved + mock_try_resolve.assert_called_with(recipe) + mock_save.assert_called_with(config=recipe, f=mock_temp_file_instance.name) + + # Verify the result is the resolved recipe + assert result == recipe + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("sagemaker.pytorch.estimator.Framework.fit") +def test_fit_with_nova_recipe_s3_upload(mock_framework_fit, mock_recipe_load, sagemaker_session): + """Test that fit correctly uploads the recipe to S3 and adds it to the inputs.""" + # Create a mock recipe that will be identified as a Nova recipe + mock_recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar", "model_name_or_path": "foobar/foobar123"}} + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("nova_recipe", mock_recipe) + + # Create a PyTorch estimator with a Nova recipe + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="nova_recipe", + ) + + # Set Nova recipe attributes + pytorch.is_nova_recipe = True + pytorch.training_recipe_file = temp_file + + # Mock the _upload_recipe_to_s3 method + with patch.object(pytorch, "_upload_recipe_to_s3") as mock_upload_recipe: + mock_upload_recipe.return_value = "s3://mybucket/recipes/nova-recipe.yaml" + + # Call the fit method + pytorch.fit() + + # Verify the upload_recipe_to_s3 method was called + mock_upload_recipe.assert_called_once_with(sagemaker_session, temp_file.name) + + # Verify the fit method was called with the recipe channel + call_args = mock_framework_fit.call_args[1] + assert "inputs" in call_args + assert "recipe" in call_args["inputs"] + + # Verify the hyperparameters were updated with the recipe path + assert "sagemaker_recipe_local_path" in pytorch._hyperparameters + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("sagemaker.pytorch.estimator.PyTorch._upload_recipe_to_s3") +@patch("sagemaker.pytorch.estimator.Framework.fit") +def test_fit_with_nova_recipe_and_inputs( + mock_framework_fit, mock_upload_recipe, 
mock_recipe_load, sagemaker_session +): + """Test that fit correctly handles Nova recipes with additional inputs.""" + # Create a mock recipe that will be identified as a Nova recipe + mock_recipe = OmegaConf.create( + {"run": {"model_type": "amazon.nova.foobar3", "model_name_or_path": "foobar/foobar-3-8b"}} + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("nova_recipe", mock_recipe) + mock_upload_recipe.return_value = "s3://mybucket/recipes/nova-recipe.yaml" + + # Create a PyTorch estimator with a Nova recipe + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + training_recipe="nova_recipe", + ) + + # Set Nova recipe attributes + pytorch.is_nova_recipe = True + pytorch.training_recipe_file = temp_file + + # Create training inputs + train_input = TrainingInput(s3_data="s3://mybucket/train") + val_input = TrainingInput(s3_data="s3://mybucket/validation") + inputs = {"train": train_input, "validation": val_input} + + # Call the fit method with inputs + pytorch.fit(inputs=inputs) + + # Verify the fit method was called with both the recipe channel and the provided inputs + call_args = mock_framework_fit.call_args[1] + assert "inputs" in call_args + assert "recipe" in call_args["inputs"] + assert "train" in call_args["inputs"] + assert "validation" in call_args["inputs"] + + # Verify the hyperparameters were updated with the recipe path + assert "sagemaker_recipe_local_path" in pytorch._hyperparameters + + +def test_device_get_distribution(): + """Test that _device_get_distribution returns the correct distribution configuration.""" + # Test with GPU device type + gpu_distribution = _device_get_distribution("gpu") + expected_gpu_distribution = { + "torch_distributed": {"enabled": True}, + "smdistributed": { + "modelparallel": { + "enabled": True, + "parameters": { + "placement_strategy": "cluster", + }, + }, + }, + } + assert gpu_distribution == expected_gpu_distribution + + # Test with Trainium device type + trainium_distribution = _device_get_distribution("trainium") + expected_trainium_distribution = { + "torch_distributed": {"enabled": True}, + } + assert trainium_distribution == expected_trainium_distribution + + # Test with CPU device type + cpu_distribution = _device_get_distribution("cpu") + assert cpu_distribution == {} + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_load") +@patch("sagemaker.pytorch.estimator.PyTorch._upload_recipe_to_s3") +@patch("sagemaker.pytorch.estimator.Framework.fit") +def test_fit_with_nova_recipe( + mock_framework_fit, mock_upload_recipe, mock_recipe_load, sagemaker_session +): + """Test that fit correctly handles Nova recipes.""" + + # Create a mock recipe that will be identified as a Nova recipe + mock_recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foo-bar", + "model_name_or_path": "foo-bar123", + } + } + ) + + # Set up the mock to return a recipe name and the mock recipe + mock_recipe_load.return_value = ("nova_recipe", mock_recipe) + + # Create a PyTorch estimator with a Nova recipe + with tempfile.NamedTemporaryFile(suffix=".yaml") as temp_file: + pytorch = PyTorch( + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + 
py_version="py3", + training_recipe="nova_recipe", + ) + + # Set Nova recipe attributes + pytorch.is_nova_recipe = True + pytorch.training_recipe_file = temp_file + + # Mock the upload_recipe_to_s3 method + mock_upload_recipe.return_value = "s3://mybucket/recipes/nova-recipe.yaml" + + # Call the fit method + pytorch.fit() + + # Verify the upload_recipe_to_s3 method was called + mock_upload_recipe.assert_called_once_with(sagemaker_session, temp_file.name) + + # Verify the fit method was called with the recipe channel + call_args = mock_framework_fit.call_args[1] + assert "inputs" in call_args + assert "recipe" in call_args["inputs"] + + # Verify the hyperparameters were updated with the recipe path + assert "sagemaker_recipe_local_path" in pytorch._hyperparameters + + +def test_nova_encode_hyperparameters(): + """Test that _nova_encode_hyperparameters correctly preserves string values and encodes non-string values.""" + # Setup test hyperparameters + hyperparameters = { + "string_param": "string_value", + "int_param": 42, + "float_param": 3.14, + "bool_param": True, + "list_param": [1, 2, 3], + "dict_param": {"key": "value"}, + } + + # Call the method + encoded = EstimatorBase._nova_encode_hyperparameters(hyperparameters) + + # Verify string values are preserved + assert encoded["string_param"] == "string_value" + + # Verify non-string values are JSON-encoded + assert encoded["int_param"] == "42" + assert encoded["float_param"] == "3.14" + assert encoded["bool_param"] == "true" + assert encoded["list_param"] == "[1, 2, 3]" + assert encoded["dict_param"] == '{"key": "value"}' + + +def test_framework_set_hyperparameters_nova(): + """Test that Framework.set_hyperparameters uses _nova_encode_hyperparameters for Nova jobs.""" + # Setup + framework = PyTorch( + entry_point="dummy.py", + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version="1.13.1", + py_version="py3", + image_uri=IMAGE_URI, + ) + + framework.is_nova_job = True + + # Add hyperparameters + framework.set_hyperparameters(string_param="string_value", int_param=42, bool_param=True) + + # Verify string values are preserved and non-string values are encoded + assert framework._hyperparameters["string_param"] == "string_value" + assert framework._hyperparameters["int_param"] == "42" + assert framework._hyperparameters["bool_param"] == "true" + + +def test_framework_set_hyperparameters_non_nova(): + """Test that Framework.set_hyperparameters uses _json_encode_hyperparameters for non-Nova jobs.""" + # Setup + framework = PyTorch( + entry_point="dummy.py", + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version="1.13.1", + py_version="py3", + image_uri=IMAGE_URI, + ) + framework.is_nova_recipe = False + + # Add hyperparameters + framework.set_hyperparameters(string_param="string_value", int_param=42, bool_param=True) + + # Verify all values are JSON-encoded + assert framework._hyperparameters["string_param"] == '"string_value"' + assert framework._hyperparameters["int_param"] == "42" + assert framework._hyperparameters["bool_param"] == "true" + + +def test_framework_hyperparameters_nova(): + """Test that Framework.hyperparameters uses _nova_encode_hyperparameters for Nova jobs.""" + # Setup + framework = PyTorch( + entry_point="dummy.py", + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version="1.13.1", + py_version="py3", + image_uri=IMAGE_URI, + ) + + framework.is_nova_job = True + + # Add hyperparameters directly to 
_hyperparameters + framework._hyperparameters = { + "string_param": "string_value", + "int_param": 42, + "bool_param": True, + } + + # Get hyperparameters + hyperparams = framework.hyperparameters() + + # Verify string values are preserved and non-string values are encoded + assert hyperparams["string_param"] == "string_value" + assert hyperparams["int_param"] == "42" + assert hyperparams["bool_param"] == "true" + + +@patch("sagemaker.pytorch.estimator.PyTorch._recipe_resolve_and_save") +def test_setup_for_nova_recipe_with_distillation(mock_resolve_save, sagemaker_session): + """Test that _setup_for_nova_recipe correctly handles distillation configurations.""" + # Create a mock recipe with distillation config + recipe = OmegaConf.create( + { + "run": { + "model_type": "amazon.nova.foobar3", + "model_name_or_path": "foobar/foobar-3-8b", + "replicas": 4, + }, + "training_config": { + "distillation_data": "s3://mybucket/distillation-data", + "kms_key": "alias/my-kms-key", + }, + } + ) + + # Setup the expected return value + expected_args = { + "hyperparameters": { + "base_model": "foobar/foobar-3-8b", + "distillation_data": "s3://mybucket/distillation-data", + "role_arn": "arn:aws:iam::123456789012:role/SageMakerRole", + "kms_key": "alias/my-kms-key", + }, + "entry_point": None, + "source_dir": None, + "distribution": {}, + "default_image_uri": IMAGE_URI, + } + + with patch( + "sagemaker.pytorch.estimator.PyTorch._setup_for_nova_recipe", return_value=expected_args + ) as mock_nova_setup: + with patch( + "sagemaker.pytorch.estimator.PyTorch._recipe_load", return_value=("nova_recipe", recipe) + ): + mock_resolve_save.return_value = recipe + + pytorch = PyTorch( + training_recipe="nova_recipe", + role="arn:aws:iam::123456789012:role/SageMakerRole", + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE_GPU, + image_uri=IMAGE_URI, + framework_version="1.13.1", + py_version="py3", + ) + + # Check that the Nova recipe was correctly identified + assert pytorch.is_nova_recipe is True + + # Verify _setup_for_nova_recipe was called + mock_nova_setup.assert_called_once() + + # Verify that hyperparameters were set correctly for distillation + assert ( + pytorch._hyperparameters.get("distillation_data") + == "s3://mybucket/distillation-data" + ) + assert pytorch._hyperparameters.get("kms_key") == "alias/my-kms-key" + assert ( + pytorch._hyperparameters.get("role_arn") + == "arn:aws:iam::123456789012:role/SageMakerRole" + ) diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index d9c4129cf6..27ab48d025 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -95,7 +95,7 @@ def _rl_estimator( framework=RLFramework.MXNET, instance_type=None, base_job_name=None, - **kwargs + **kwargs, ): return RLEstimator( entry_point=SCRIPT_PATH, @@ -107,7 +107,7 @@ def _rl_estimator( instance_count=INSTANCE_COUNT, instance_type=instance_type or INSTANCE_TYPE, base_job_name=base_job_name, - **kwargs + **kwargs, ) @@ -335,7 +335,7 @@ def test_rl(time, strftime, sagemaker_session, coach_mxnet_version): sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job( RLToolkit.COACH.value, coach_mxnet_version, RLFramework.MXNET.value diff --git a/tests/unit/test_s3.py b/tests/unit/test_s3.py index 
a226954986..b54552cacb 100644 --- a/tests/unit/test_s3.py +++ b/tests/unit/test_s3.py @@ -17,6 +17,7 @@ from mock import Mock from sagemaker import s3 +from sagemaker.s3_utils import is_s3_url BUCKET_NAME = "mybucket" REGION = "us-west-2" @@ -132,6 +133,34 @@ def test_parse_s3_url_fail(): assert "Expecting 's3' scheme" in str(error) +@pytest.mark.parametrize( + "input_url", + [ + ("s3://bucket/code_location"), + ("s3://bucket/code_location/sub_location"), + ("s3://bucket/code_location/sub_location/"), + ("s3://bucket/"), + ("s3://bucket"), + ], +) +def test_is_s3_url_true(input_url): + assert is_s3_url(input_url) is True + + +@pytest.mark.parametrize( + "input_url", + [ + ("bucket/code_location"), + ("bucket/code_location/sub_location"), + ("sub_location/"), + ("s3/bucket/"), + ("t3://bucket"), + ], +) +def test_is_s3_url_false(input_url): + assert is_s3_url(input_url) is False + + @pytest.mark.parametrize( "expected_output, input_args", [ diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 944f22acff..e3d763e612 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -24,6 +24,8 @@ from botocore.exceptions import ClientError from mock import ANY, MagicMock, Mock, patch, call, mock_open +from sagemaker.model_card.schema_constraints import ModelCardStatusEnum + from .common import _raise_unexpected_client_error import sagemaker from sagemaker import TrainingInput, Session, get_execution_role, exceptions @@ -43,8 +45,6 @@ from sagemaker.utils import update_list_of_dicts_with_values_from_config from sagemaker.user_agent import ( SDK_PREFIX, - STUDIO_PREFIX, - NOTEBOOK_PREFIX, ) from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit import ( @@ -87,15 +87,20 @@ limits={}, ) +SDK_DEFAULT_SUFFIX = f"lib/{SDK_PREFIX}#2.218.0" +NOTEBOOK_SUFFIX = f"{SDK_DEFAULT_SUFFIX} md/AWS-SageMaker-Notebook-Instance#instance_type" +STUDIO_SUFFIX = f"{SDK_DEFAULT_SUFFIX} md/AWS-SageMaker-Studio#app_type" -@pytest.fixture() -def boto_session(): - boto_mock = Mock(name="boto_session", region_name=REGION) +@pytest.fixture +def boto_session(request): + boto_user_agent = "Boto3/1.33.9 md/Botocore#1.33.9 ua/2.0 os/linux#linux-ver md/arch#x86_64 lang/python#3.10.6" + user_agent_suffix = getattr(request, "param", "") + boto_mock = Mock(name="boto_session", region_name=REGION) client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" - ) + user_agent = f"{boto_user_agent} {SDK_DEFAULT_SUFFIX} {user_agent_suffix}" + with patch("sagemaker.user_agent.get_user_agent_extra_suffix", return_value=user_agent_suffix): + client_mock._client_config.user_agent = user_agent boto_mock.client.return_value = client_mock return boto_mock @@ -887,65 +892,42 @@ def test_delete_model(boto_session): boto_session.client().delete_model.assert_called_with(ModelName=model_name) +@pytest.mark.parametrize("boto_session", [""], indirect=True) def test_user_agent_injected(boto_session): - assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent - sess = Session(boto_session) - + expected_user_agent_suffix = "lib/AWS-SageMaker-Python-SDK#2.218.0" for client in [ sess.sagemaker_client, sess.sagemaker_runtime_client, sess.sagemaker_metrics_client, ]: - assert SDK_PREFIX in client._client_config.user_agent - assert NOTEBOOK_PREFIX not in client._client_config.user_agent - assert STUDIO_PREFIX not in client._client_config.user_agent - + 
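# Each client created from the Session should carry the SDK suffix appended via the user_agent_extra hook. +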
assert expected_user_agent_suffix in client._client_config.user_agent -@patch("sagemaker.user_agent.process_notebook_metadata_file", return_value="ml.t3.medium") -def test_user_agent_injected_with_nbi( - mock_process_notebook_metadata_file, - boto_session, -): - assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent - sess = Session( - boto_session=boto_session, +@pytest.mark.parametrize("boto_session", [f"{NOTEBOOK_SUFFIX}"], indirect=True) +def test_user_agent_with_notebook_instance_type(boto_session): + sess = Session(boto_session) + expected_user_agent_suffix = ( + "lib/AWS-SageMaker-Python-SDK#2.218.0 md/AWS-SageMaker-Notebook-Instance#instance_type" ) - for client in [ sess.sagemaker_client, sess.sagemaker_runtime_client, sess.sagemaker_metrics_client, ]: - mock_process_notebook_metadata_file.assert_called() + assert expected_user_agent_suffix in client._client_config.user_agent - assert SDK_PREFIX in client._client_config.user_agent - assert NOTEBOOK_PREFIX in client._client_config.user_agent - assert STUDIO_PREFIX not in client._client_config.user_agent - - -@patch("sagemaker.user_agent.process_studio_metadata_file", return_value="dymmy-app-type") -def test_user_agent_injected_with_studio_app_type( - mock_process_studio_metadata_file, - boto_session, -): - assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent - - sess = Session( - boto_session=boto_session, - ) +@pytest.mark.parametrize("boto_session", [f"{STUDIO_SUFFIX}"], indirect=True) +def test_user_agent_with_studio_app_type(boto_session): + sess = Session(boto_session) + expected_user_agent = "lib/AWS-SageMaker-Python-SDK#2.218.0 md/AWS-SageMaker-Studio#app_type" for client in [ sess.sagemaker_client, sess.sagemaker_runtime_client, sess.sagemaker_metrics_client, ]: - mock_process_studio_metadata_file.assert_called() - - assert SDK_PREFIX in client._client_config.user_agent - assert NOTEBOOK_PREFIX not in client._client_config.user_agent - assert STUDIO_PREFIX in client._client_config.user_agent + assert expected_user_agent in client._client_config.user_agent def test_training_input_all_defaults(): @@ -5024,6 +5006,7 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session) domain = "COMPUTER_VISION" task = "IMAGE_CLASSIFICATION" sample_payload_url = "s3://test-bucket/model" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5112,6 +5095,8 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec skip_model_validation = "All" source_uri = "dummy-source-uri" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + created_versioned_mp_arn = ( "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" ) @@ -5167,6 +5152,7 @@ def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp approval_status = ("Approved",) skip_model_validation = "All" source_uri = "dummy-source-uri" + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} with pytest.raises( ValueError, @@ -5239,6 +5225,8 @@ def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemake return_value={"ModelPackageArn": created_versioned_mp_arn} ) + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + sagemaker_session.create_model_package_from_containers( 
model_package_group_name=model_package_group_name, containers=containers, @@ -5363,6 +5351,26 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): domain = "COMPUTER_VISION" task = "IMAGE_CLASSIFICATION" sample_payload_url = "s3://test-bucket/model" + model_card = { + "ModelCardStatus": ModelCardStatusEnum.DRAFT, + "Content": { + "model_overview": { + "model_creator": "TestCreator", + }, + "intended_uses": { + "purpose_of_model": "Test model card.", + "intended_uses": "Not used except this test.", + "factors_affecting_model_efficiency": "No.", + "risk_rating": "Low", + "explanations_for_risk_rating": "Just an example.", + }, + }, + } + model_life_cycle = { + "Stage": "Development", + "StageStatus": "In-Progress", + "StageDescription": "Sending for Staging Verification", + } sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5381,6 +5389,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): sample_payload_url=sample_payload_url, task=task, skip_model_validation=skip_model_validation, + model_card=model_card, + model_life_cycle=model_life_cycle, ) expected_args = { "ModelPackageName": model_package_name, @@ -5402,6 +5412,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "SamplePayloadUrl": sample_payload_url, "Task": task, "SkipModelValidation": skip_model_validation, + "ModelCard": model_card, + "ModelLifeCycle": model_life_cycle, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) @@ -5437,6 +5449,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s approval_status = ("Approved",) description = "description" customer_metadata_properties = {"key1": "value1"} + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -5504,6 +5517,7 @@ def test_create_model_package_from_containers_with_one_instance_types( approval_status = ("Approved",) description = "description" customer_metadata_properties = {"key1": "value1"} + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -6283,6 +6297,24 @@ def test_create_inference_recommendations_job_propogate_other_exception( assert "AccessDeniedException" in str(error) +def test_create_presigned_mlflow_tracking_server_url(sagemaker_session): + sagemaker_session.create_presigned_mlflow_tracking_server_url("ts", 1, 2) + assert ( + sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.called_with( + TrackingServerName="ts", ExpiresInSeconds=1, SessionExpirationDurationInSeconds=2 + ) + ) + + +def test_create_presigned_mlflow_tracking_server_url_minimal(sagemaker_session): + sagemaker_session.create_presigned_mlflow_tracking_server_url("ts") + assert ( + sagemaker_session.sagemaker_client.create_presigned_mlflow_tracking_server_url.called_with( + TrackingServerName="ts" + ) + ) + + DEFAULT_LOG_EVENTS_INFERENCE_RECOMMENDER = [ MockBotoException("ResourceNotFoundException"), {"nextForwardToken": None, "events": [{"timestamp": 1, "message": "hi there #1"}]}, @@ -6992,3 +7024,264 @@ def test_download_data_with_file_and_directory(makedirs, sagemaker_session): Filename="./foo/bar/mode.tar.gz", ExtraArgs=None, ) + + +def test_create_hub(sagemaker_session): + 
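# create_hub is expected to forward each snake_case kwarg as the matching PascalCase request field asserted below. +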
sagemaker_session.create_hub( + hub_name="mock-hub-name", + hub_description="this is my sagemaker hub", + hub_display_name="Mock Hub", + hub_search_keywords=["mock", "hub", "123"], + s3_storage_config={"S3OutputPath": "s3://my-hub-bucket/"}, + tags=[{"Key": "tag-key-1", "Value": "tag-value-1"}], + ) + + request = { + "HubName": "mock-hub-name", + "HubDescription": "this is my sagemaker hub", + "HubDisplayName": "Mock Hub", + "HubSearchKeywords": ["mock", "hub", "123"], + "S3StorageConfig": {"S3OutputPath": "s3://my-hub-bucket/"}, + "Tags": [{"Key": "tag-key-1", "Value": "tag-value-1"}], + } + + sagemaker_session.sagemaker_client.create_hub.assert_called_with(**request) + + +def test_describe_hub(sagemaker_session): + sagemaker_session.describe_hub( + hub_name="mock-hub-name", + ) + + request = { + "HubName": "mock-hub-name", + } + + sagemaker_session.sagemaker_client.describe_hub.assert_called_with(**request) + + +def test_list_hubs(sagemaker_session): + sagemaker_session.list_hubs( + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08-2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + name_contains="mock-hub", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08-2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "NameContains": "mock-hub", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hubs.assert_called_with(**request) + + +def test_list_hub_contents(sagemaker_session): + sagemaker_session.list_hub_contents( + hub_name="mock-hub-123", + hub_content_type="MODELREF", + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08/2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + name_contains="mock-hub", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODELREF", + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08/2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "NameContains": "mock-hub", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hub_contents.assert_called_with(**request) + + +def test_list_hub_content_versions(sagemaker_session): + sagemaker_session.list_hub_content_versions( + hub_name="mock-hub-123", + hub_content_type="MODELREF", + hub_content_name="mock-hub-content-1", + min_version="1.0.0", + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08/2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODELREF", + "HubContentName": "mock-hub-content-1", + "MinVersion": "1.0.0", + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08/2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hub_content_versions.assert_called_with(**request) + + +def test_delete_hub(sagemaker_session): + sagemaker_session.delete_hub( + hub_name="mock-hub-123", + ) + + request = { + "HubName": "mock-hub-123", + } + + sagemaker_session.sagemaker_client.delete_hub.assert_called_with(**request) + + +def test_create_hub_content_reference(sagemaker_session): + sagemaker_session.create_hub_content_reference( + hub_name="mock-hub-name", + 
source_hub_content_arn=( + "arn:aws:sagemaker:us-east-1:" + "123456789123:" + "hub-content/JumpStartHub/" + "model/mock-hub-content-1" + ), + hub_content_name="mock-hub-content-1", + min_version="1.1.1", + ) + + request = { + "HubName": "mock-hub-name", + "SageMakerPublicHubContentArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/mock-hub-content-1", # noqa: E501 + "HubContentName": "mock-hub-content-1", + "MinVersion": "1.1.1", + } + + sagemaker_session.sagemaker_client.create_hub_content_reference.assert_called_with(**request) + + +def test_delete_hub_content_reference(sagemaker_session): + sagemaker_session.delete_hub_content_reference( + hub_name="mock-hub-name", + hub_content_type="ModelReference", + hub_content_name="mock-hub-content-1", + ) + + request = { + "HubName": "mock-hub-name", + "HubContentType": "ModelReference", + "HubContentName": "mock-hub-content-1", + } + + sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request) + + +def test_create_model_package_from_containers_to_create_mpg_if_not_present_without_search( + sagemaker_session, +): + sagemaker_session.sagemaker_client.search.side_effect = Exception() + sagemaker_session.sagemaker_client.search.return_value = {} + sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [ + { + "ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg"}], + "NextToken": "NextToken", + }, + {"ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg-test"}]}, + ] + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", + model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg", + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [ + {"ModelPackageGroupSummaryList": []} + ] + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( + ModelPackageGroupName="mock-mpg" + ) + + +def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagemaker_session): + # with search api + sagemaker_session.sagemaker_client.search.return_value = { + "Results": [ + { + "ModelPackageGroup": { + "ModelPackageGroupName": "mock-mpg", + "ModelPackageGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package-group/mock-mpg", + } + } + ] + } + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", + model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg", + ) + sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called() + sagemaker_session.sagemaker_client.search.return_value = {"Results": []} + sagemaker_session.create_model_package_from_containers( + source_uri="mock-source-uri", model_package_group_name="mock-mpg" + ) + 
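# with no matching Search results, the group should be created on the fly before the package is registered +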
sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( + ModelPackageGroupName="mock-mpg" + ) + + +def test_get_most_recently_created_approved_model_package(sagemaker_session): + sagemaker_session.sagemaker_client.list_model_packages.side_effect = [ + ( + { + "ModelPackageSummaryList": [], + "NextToken": "NextToken", + } + ), + ( + { + "ModelPackageSummaryList": [ + { + "CreationTime": 1697440162, + "ModelApprovalStatus": "Approved", + "ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package/model-version/3", + "ModelPackageGroupName": "model-version", + "ModelPackageVersion": 3, + }, + ], + } + ), + ] + model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name="mpg" + ) + assert model_package is not None + assert ( + model_package.model_package_arn + == "arn:aws:sagemaker:us-west-2:123456789012:model-package/model-version/3" + ) diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index b0df31fee1..c418be4646 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -332,7 +332,7 @@ def test_sklearn(time, strftime, sagemaker_session, sklearn_version): sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] assert sagemaker_call_names == ["train", "logs_for_job"] boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] - assert boto_call_names == ["resource"] + assert "resource" in boto_call_names expected_train_args = _create_train_job(sklearn_version) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index f0325b79e9..b4d21008b5 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -46,7 +46,54 @@ from sagemaker.workflow.parameters import ParameterString, ParameterInteger from src.sagemaker.tuner import InstanceConfig -from .tuner_test_utils import * # noqa: F403 +from .tuner_test_utils import ( + BASE_JOB_NAME, + BUCKET_NAME, + CategoricalParameter, + ContinuousParameter, + DATA_DIR, + EARLY_STOPPING_TYPE, + Estimator, + ESTIMATOR, + ESTIMATOR_NAME, + ESTIMATOR_NAME_TWO, + ESTIMATOR_TWO, + FRAMEWORK_VERSION, + HYPERPARAMETER_RANGES, + HYPERPARAMETER_RANGES_TWO, + IMAGE_NAME, + INPUTS, + INSTANCE_COUNT, + INSTANCE_TYPE, + IntegerParameter, + JOB_NAME, + LIST_TAGS_RESULT, + MAX_JOBS, + MAX_PARALLEL_JOBS, + METRIC_DEFINITIONS, + MODEL_DATA, + MULTI_ALGO_TUNING_JOB_DETAILS, + NUM_COMPONENTS, + OBJECTIVE_METRIC_NAME, + OBJECTIVE_METRIC_NAME_TWO, + OBJECTIVE_TYPE, + PCA, + PY_VERSION, + REGION, + ROLE, + SAGEMAKER_SESSION, + SCRIPT_NAME, + STRATEGY, + TAGS, + TRAINING_JOB_DESCRIPTION, + TRAINING_JOB_NAME, + TUNING_JOB_DETAILS, + WarmStartConfig, + WarmStartTypes, + WARM_START_CONFIG, + ENDPOINT_DESC, + ENDPOINT_CONFIG_DESC, +) @pytest.fixture() diff --git a/tests/unit/test_tuner_visualize.py b/tests/unit/test_tuner_visualize.py new file mode 100644 index 0000000000..8397ae8e25 --- /dev/null +++ b/tests/unit/test_tuner_visualize.py @@ -0,0 +1,307 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +"""Tests related to amtviz.visualization""" +from __future__ import absolute_import + +import pandas as pd +import pytest +from mock import Mock, patch, MagicMock +import sagemaker +from sagemaker.estimator import Estimator +from sagemaker.session_settings import SessionSettings +from sagemaker.tuner import HyperparameterTuner +from tests.unit.tuner_test_utils import ( + OBJECTIVE_METRIC_NAME, + HYPERPARAMETER_RANGES, + METRIC_DEFINITIONS, +) + +# Visualization specific imports +from sagemaker.amtviz.visualization import visualize_tuning_job, get_job_analytics_data +from tests.unit.tuner_visualize_test_utils import ( + TUNING_JOB_NAMES, + TUNED_PARAMETERS, + OBJECTIVE_NAME, + TRIALS_DF_DATA, + FULL_DF_DATA, + TUNING_JOB_NAME_1, + TUNING_JOB_NAME_2, + TUNING_JOB_RESULT, + TRIALS_DF_COLUMNS, + FULL_DF_COLUMNS, + TRIALS_DF_TRAINING_JOB_NAMES, + TRIALS_DF_TRAINING_JOB_STATUSES, + TRIALS_DF_VALID_F1_VALUES, + FILTERED_TUNING_JOB_DF_DATA, + TUNING_RANGES, +) +import altair as alt + + +def create_sagemaker_session(): + boto_mock = Mock(name="boto_session") + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + config=None, + local_mode=False, + settings=SessionSettings(), + ) + sms.sagemaker_config = {} + return sms + + +@pytest.fixture() +def sagemaker_session(): + return create_sagemaker_session() + + +@pytest.fixture() +def estimator(sagemaker_session): + return Estimator( + "image", + "role", + 1, + "ml.c4.xlarge", + output_path="s3://bucket/prefix", + sagemaker_session=sagemaker_session, + ) + + +@pytest.fixture() +def tuner(estimator): + return HyperparameterTuner( + estimator, OBJECTIVE_METRIC_NAME, HYPERPARAMETER_RANGES, METRIC_DEFINITIONS + ) + + +@pytest.fixture() +def tuner2(estimator): + return HyperparameterTuner( + estimator, OBJECTIVE_METRIC_NAME, HYPERPARAMETER_RANGES, METRIC_DEFINITIONS + ) + + +@pytest.fixture +def mock_visualize_tuning_job(): + with patch("sagemaker.amtviz.visualize_tuning_job") as mock_visualize: + mock_visualize.return_value = "mock_chart" + yield mock_visualize + + +@pytest.fixture +def mock_get_job_analytics_data(): + with patch("sagemaker.amtviz.visualization.get_job_analytics_data") as mock: + mock.return_value = (pd.DataFrame(TRIALS_DF_DATA), TUNED_PARAMETERS, OBJECTIVE_NAME, True) + yield mock + + +@pytest.fixture +def mock_prepare_consolidated_df(): + with patch("sagemaker.amtviz.visualization._prepare_consolidated_df") as mock: + mock.return_value = pd.DataFrame(FULL_DF_DATA) + yield mock + + +# Test graceful handling if the required altair library is not installed +def test_visualize_jobs_altair_not_installed(capsys): + # Mock importlib.import_module to raise ImportError for 'altair' + with patch("importlib.import_module") as mock_import: + mock_import.side_effect = ImportError("No module named 'altair'") + result = HyperparameterTuner.visualize_jobs(TUNING_JOB_NAMES) + assert result is None + captured = capsys.readouterr() + assert "Altair is not installed." 
in captured.out + assert "pip install altair" in captured.out + + +# Test basic method call if altair is installed +def test_visualize_jobs_altair_installed(mock_visualize_tuning_job): + # Mock successful import of altair + with patch("importlib.import_module"): + result = HyperparameterTuner.visualize_jobs(TUNING_JOB_NAMES) + assert result == "mock_chart" + + +# Test for static method visualize_jobs() +def test_visualize_jobs(mock_visualize_tuning_job): + result = HyperparameterTuner.visualize_jobs(TUNING_JOB_NAMES) + assert result == "mock_chart" + mock_visualize_tuning_job.assert_called_once_with( + TUNING_JOB_NAMES, return_dfs=False, job_metrics=None, trials_only=False, advanced=False + ) + # Vary the parameters and check if they have been passed correctly + result = HyperparameterTuner.visualize_jobs( + [TUNING_JOB_NAME_1], + return_dfs=True, + job_metrics="job_metrics", + trials_only=True, + advanced=True, + ) + mock_visualize_tuning_job.assert_called_with( + [TUNING_JOB_NAME_1], + return_dfs=True, + job_metrics="job_metrics", + trials_only=True, + advanced=True, + ) + + +# Test the instance method visualize_job() on a stubbed tuner object +def test_visualize_job(tuner, mock_visualize_tuning_job): + # With default parameters + result = tuner.visualize_job() + assert result == "mock_chart" + mock_visualize_tuning_job.assert_called_once_with( + tuner, return_dfs=False, job_metrics=None, trials_only=False, advanced=False + ) + # With varying parameters + result = tuner.visualize_job( + return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True + ) + assert result == "mock_chart" + mock_visualize_tuning_job.assert_called_with( + tuner, return_dfs=True, job_metrics="job_metrics", trials_only=True, advanced=True + ) + + +# Test the static method visualize_jobs() on multiple stubbed tuner objects +def test_visualize_multiple_jobs(tuner, tuner2, mock_visualize_tuning_job): + result = HyperparameterTuner.visualize_jobs([tuner, tuner2]) + assert result == "mock_chart" + mock_visualize_tuning_job.assert_called_once_with( + [tuner, tuner2], return_dfs=False, job_metrics=None, trials_only=False, advanced=False + ) + # Vary the parameters and check if they have been passed correctly + result = HyperparameterTuner.visualize_jobs( + [[tuner, tuner2]], + return_dfs=True, + job_metrics="job_metrics", + trials_only=True, + advanced=True, + ) + mock_visualize_tuning_job.assert_called_with( + [[tuner, tuner2]], + return_dfs=True, + job_metrics="job_metrics", + trials_only=True, + advanced=True, + ) + + +# Test direct method call for basic chart return type and default render settings +def test_visualize_tuning_job_analytics_data_results_in_altair_chart(mock_get_job_analytics_data): + result = visualize_tuning_job("mock_job") + assert alt.renderers.active == "default" + assert isinstance(result, alt.VConcatChart) + + +# Test the size and structure of the returned dataframes (trials_df and full_df) +def test_visualize_tuning_job_return_dfs(mock_get_job_analytics_data, mock_prepare_consolidated_df): + charts, trials_df, full_df = visualize_tuning_job("mock_job", return_dfs=True) + # Basic assertion for the charts + assert isinstance(charts, alt.VConcatChart) + + # Assertions for trials_df + assert isinstance(trials_df, pd.DataFrame) + assert trials_df.shape == (2, len(TRIALS_DF_COLUMNS)) + assert trials_df.columns.tolist() == TRIALS_DF_COLUMNS + assert trials_df["TrainingJobName"].tolist() == TRIALS_DF_TRAINING_JOB_NAMES + assert trials_df["TrainingJobStatus"].tolist() == 
TRIALS_DF_TRAINING_JOB_STATUSES + assert trials_df["TuningJobName"].tolist() == TUNING_JOB_NAMES + assert trials_df["valid-f1"].tolist() == TRIALS_DF_VALID_F1_VALUES + + # Assertions for full_df + assert isinstance(full_df, pd.DataFrame) + assert full_df.shape == (2, 16) + assert full_df.columns.tolist() == FULL_DF_COLUMNS + + +# Test the handling of an empty trials dataframe +@patch("sagemaker.amtviz.visualization.get_job_analytics_data") +def test_visualize_tuning_job_empty_trials(mock_get_job_analytics_data): + mock_get_job_analytics_data.return_value = ( + pd.DataFrame(), # empty dataframe + TUNED_PARAMETERS, + OBJECTIVE_NAME, + True, + ) + charts = visualize_tuning_job("empty_job") + assert charts.empty + + +# Test handling of return_dfs and trials_only parameter +def test_visualize_tuning_job_trials_only(mock_get_job_analytics_data): + # If return_dfs is set to False, then only charts should be returned + result = visualize_tuning_job("mock_job", return_dfs=False, trials_only=True) + assert isinstance(result, alt.VConcatChart) + # Trials_only controls the content of the two returned dataframes (trials_df, full_df) + result, df1, df2 = visualize_tuning_job("mock_job", return_dfs=True, trials_only=True) + assert isinstance(df1, pd.DataFrame) + assert df1.shape == (2, len(TRIALS_DF_COLUMNS)) + assert isinstance(df2, pd.DataFrame) + assert df2.empty + # The combination of return_dfs and trials_only=False is covered in 'test_visualize_tuning_job_return_dfs' + + +# Check if all parameters are correctly passed to the (mocked) create_charts method +@patch("sagemaker.amtviz.visualization.create_charts") +def test_visualize_tuning_job_with_full_df( + mock_create_charts, mock_get_job_analytics_data, mock_prepare_consolidated_df +): + mock_create_charts.return_value = alt.Chart() + visualize_tuning_job("dummy_job") + + # Check the create_charts call arguments + call_args = mock_create_charts.call_args[0] + call_kwargs = mock_create_charts.call_args[1] + assert isinstance(call_args[0], pd.DataFrame) # trials_df + assert isinstance(call_args[1], list) # tuned_parameters + assert isinstance(call_args[2], pd.DataFrame) # full_df + assert isinstance(call_args[3], str) # objective_name + assert call_kwargs.get("minimize_objective") + + # Check the details of the passed arguments + trials_df = call_args[0] + assert trials_df.columns.tolist() == TRIALS_DF_COLUMNS + tuned_parameters = call_args[1] + assert tuned_parameters == TUNED_PARAMETERS + objective_name = call_args[3] + assert objective_name == OBJECTIVE_NAME + full_df = call_args[2] + assert full_df.columns.tolist() == FULL_DF_COLUMNS + + +# Test the dataframe produced by get_job_analytics_data() +@patch("sagemaker.HyperparameterTuningJobAnalytics") +def test_get_job_analytics_data(mock_hyperparameter_tuning_job_analytics): + # Mock sagemaker's describe_hyper_parameter_tuning_job and some internal methods + sagemaker.amtviz.visualization.sm.describe_hyper_parameter_tuning_job = Mock( + return_value=TUNING_JOB_RESULT + ) + sagemaker.amtviz.visualization._get_tuning_job_names_with_parents = Mock( + return_value=[TUNING_JOB_NAME_1, TUNING_JOB_NAME_2] + ) + sagemaker.amtviz.visualization._get_df = Mock( + return_value=pd.DataFrame(FILTERED_TUNING_JOB_DF_DATA) + ) + mock_tuning_job_instance = MagicMock() + mock_hyperparameter_tuning_job_analytics.return_value = mock_tuning_job_instance + mock_tuning_job_instance.tuning_ranges.values.return_value = TUNING_RANGES + + df, tuned_parameters, objective_name, is_minimize = 
get_job_analytics_data([TUNING_JOB_NAME_1]) + assert df.shape == (4, 12) + assert df.columns.tolist() == TRIALS_DF_COLUMNS + assert tuned_parameters == TUNED_PARAMETERS + assert objective_name == OBJECTIVE_NAME + assert is_minimize is False diff --git a/tests/unit/test_user_agent.py b/tests/unit/test_user_agent.py index c116fef951..fb46988e7b 100644 --- a/tests/unit/test_user_agent.py +++ b/tests/unit/test_user_agent.py @@ -13,20 +13,17 @@ from __future__ import absolute_import import json -from mock import MagicMock, patch, mock_open +from mock import patch, mock_open from sagemaker.user_agent import ( SDK_PREFIX, SDK_VERSION, - PYTHON_VERSION, - OS_NAME_VERSION, NOTEBOOK_PREFIX, STUDIO_PREFIX, process_notebook_metadata_file, process_studio_metadata_file, - determine_prefix, - prepend_user_agent, + get_user_agent_extra_suffix, ) @@ -60,45 +57,18 @@ def test_process_studio_metadata_file_not_exists(tmp_path): assert process_studio_metadata_file() is None -# Test determine_prefix function -def test_determine_prefix_notebook_instance_type(monkeypatch): - monkeypatch.setattr( - "sagemaker.user_agent.process_notebook_metadata_file", lambda: "instance_type" - ) - assert ( - determine_prefix() - == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION} {NOTEBOOK_PREFIX}/instance_type" - ) - - -def test_determine_prefix_studio_app_type(monkeypatch): - monkeypatch.setattr( - "sagemaker.user_agent.process_studio_metadata_file", lambda: "studio_app_type" - ) - assert ( - determine_prefix() - == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION} {STUDIO_PREFIX}/studio_app_type" - ) - - -def test_determine_prefix_no_metadata(monkeypatch): - monkeypatch.setattr("sagemaker.user_agent.process_notebook_metadata_file", lambda: None) - monkeypatch.setattr("sagemaker.user_agent.process_studio_metadata_file", lambda: None) - assert determine_prefix() == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION}" - - -# Test prepend_user_agent function -def test_prepend_user_agent_existing_user_agent(monkeypatch): - client = MagicMock() - client._client_config.user_agent = "existing_user_agent" - monkeypatch.setattr("sagemaker.user_agent.determine_prefix", lambda _: "prefix") - prepend_user_agent(client) - assert client._client_config.user_agent == "prefix existing_user_agent" - - -def test_prepend_user_agent_no_user_agent(monkeypatch): - client = MagicMock() - client._client_config.user_agent = None - monkeypatch.setattr("sagemaker.user_agent.determine_prefix", lambda _: "prefix") - prepend_user_agent(client) - assert client._client_config.user_agent == "prefix" +# Test get_user_agent_extra_suffix function +def test_get_user_agent_extra_suffix(): + assert get_user_agent_extra_suffix() == f"lib/{SDK_PREFIX}#{SDK_VERSION}" + + with patch("sagemaker.user_agent.process_notebook_metadata_file", return_value="instance_type"): + assert ( + get_user_agent_extra_suffix() + == f"lib/{SDK_PREFIX}#{SDK_VERSION} md/{NOTEBOOK_PREFIX}#instance_type" + ) + + with patch("sagemaker.user_agent.process_studio_metadata_file", return_value="studio_type"): + assert ( + get_user_agent_extra_suffix() + == f"lib/{SDK_PREFIX}#{SDK_VERSION} md/{STUDIO_PREFIX}#studio_type" + ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 81d8279e6d..f243bf1635 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -15,6 +15,7 @@ from __future__ import absolute_import import copy +import logging import shutil import tarfile from datetime import datetime @@ -30,11 +31,14 @@ from mock 
import call, patch, Mock, MagicMock, PropertyMock import sagemaker +from sagemaker.enums import RoutingStrategy from sagemaker.experiments._run_context import _RunContext from sagemaker.session_settings import SessionSettings from sagemaker.utils import ( + camel_case_to_pascal_case, deep_override_dict, flatten_dict, + get_domain_for_region, get_instance_type_family, retry_with_backoff, check_and_get_run_experiment_config, @@ -50,7 +54,14 @@ _is_bad_link, custom_extractall_tarfile, can_model_package_source_uri_autopopulate, + get_instance_rate_per_hour, + extract_instance_rate_per_hour, + _resolve_routing_config, + tag_exists, + _validate_new_tags, + remove_tag_with_key, ) +from src.sagemaker.config.config_utils import _log_sagemaker_config_single_substitution from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -1271,6 +1282,87 @@ def test_resolve_value_from_config(): mock_info_logger.reset_mock() +class TestLogSagemakerConfig(TestCase): + + def test_sensitive_info_masking(self): + logger = logging.getLogger("sagemaker.config") + logger.setLevel(logging.DEBUG) + + stream_handler = logging.StreamHandler() + logger.addHandler(stream_handler) + + # source value is None + with self.assertLogs(logger, level="DEBUG") as log: + _log_sagemaker_config_single_substitution( + None, {"apiKey": "topsecretkey"}, "config/path" + ) + + self.assertIn("config value that will be used = {'apiKey': '***'}", log.output[0]) + + # source value is None and config_value == source_value + with self.assertLogs(logger, level="DEBUG") as log: + _log_sagemaker_config_single_substitution( + {"secretword": "topsecretword"}, {"secretword": "topsecretword"}, "config/path" + ) + + self.assertIn("Skipped value", log.output[0]) + self.assertIn("source value that will be used = {'secretword': '***'}", log.output[0]) + self.assertIn("config value = {'secretword': '***'}", log.output[0]) + + # source value is not None and config_value != source_value + with self.assertLogs(logger, level="DEBUG") as log: + _log_sagemaker_config_single_substitution( + {"password": "supersecretpassword"}, {"apiKey": "topsecretkey"}, "config/path" + ) + + self.assertIn("Skipped value", log.output[0]) + self.assertIn("source value that will be used = {'password': '***'}", log.output[0]) + self.assertIn("config value = {'apiKey': '***'}", log.output[0]) + + def test_non_sensitive_info_masking(self): + logger = logging.getLogger("sagemaker.config") + logger.setLevel(logging.DEBUG) + + stream_handler = logging.StreamHandler() + logger.addHandler(stream_handler) + + # source value is None + with self.assertLogs(logger, level="DEBUG") as log: + _log_sagemaker_config_single_substitution( + None, {"username": "randomvalue"}, "config/path" + ) + + self.assertIn("config value that will be used = {'username': 'randomvalue'}", log.output[0]) + + # source value is not None and config_value == source_value + with self.assertLogs(logger, level="DEBUG") as log: + _log_sagemaker_config_single_substitution( + {"nonsensitivevalue": "randomvalue"}, + {"nonsensitivevalue": "randomvalue"}, + "config/path", + ) + + self.assertIn("Skipped value", log.output[0]) + self.assertIn( + "source value that will be used = {'nonsensitivevalue': 'randomvalue'}", log.output[0] + ) + self.assertIn("config value = {'nonsensitivevalue': 'randomvalue'}", log.output[0]) + + # source value is not None and config_value != source_value + with self.assertLogs(logger, level="DEBUG") as log: + 
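# non-sensitive keys should be logged verbatim here rather than masked as '***' +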
_log_sagemaker_config_single_substitution( + {"username": "nonsensitiveinfo"}, + {"configvalue": "nonsensitivevalue"}, + "config/path/non_sensitive", + ) + + self.assertIn("Skipped value", log.output[0]) + self.assertIn( + "source value that will be used = {'username': 'nonsensitiveinfo'}", log.output[0] + ) + self.assertIn("config value = {'configvalue': 'nonsensitivevalue'}", log.output[0]) + + def test_get_sagemaker_config_value(): mock_config_logger = Mock() @@ -1724,6 +1816,8 @@ def test_volume_size_not_supported(self): "local", "local_gpu", ParameterString(name="InstanceType", default_value="ml.m4.xlarge"), + "ml.trn1.32xlarge", + "ml.trn1n.32xlarge", ] for instance in instances_that_dont_support_volume_size: @@ -1817,7 +1911,13 @@ def test_can_model_package_source_uri_autopopulate(): class TestDeepMergeDict(TestCase): def test_flatten_dict_basic(self): nested_dict = {"a": 1, "b": {"x": 2, "y": {"p": 3, "q": 4}}, "c": 5} - flattened_dict = {"a": 1, "b.x": 2, "b.y.p": 3, "b.y.q": 4, "c": 5} + flattened_dict = { + ("a",): 1, + ("b", "x"): 2, + ("b", "y", "p"): 3, + ("b", "y", "q"): 4, + ("c",): 5, + } self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) @@ -1829,13 +1929,19 @@ def test_flatten_dict_empty(self): def test_flatten_dict_no_nested(self): nested_dict = {"a": 1, "b": 2, "c": 3} - flattened_dict = {"a": 1, "b": 2, "c": 3} + flattened_dict = {("a",): 1, ("b",): 2, ("c",): 3} self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) def test_flatten_dict_with_various_types(self): nested_dict = {"a": [1, 2, 3], "b": {"x": None, "y": {"p": [], "q": ""}}, "c": 9} - flattened_dict = {"a": [1, 2, 3], "b.x": None, "b.y.p": [], "b.y.q": "", "c": 9} + flattened_dict = { + ("a",): [1, 2, 3], + ("b", "x"): None, + ("b", "y", "p"): [], + ("b", "y", "q"): "", + ("c",): 9, + } self.assertDictEqual(flatten_dict(nested_dict), flattened_dict) self.assertDictEqual(unflatten_dict(flattened_dict), nested_dict) @@ -1857,6 +1963,18 @@ def test_deep_override_nested_lists(self): expected_merged = {"a": [5], "b": {"c": [6, 7], "d": [8]}} self.assertDictEqual(deep_override_dict(dict1, dict2), expected_merged) + def test_deep_override_nested_lists_overriding_none(self): + dict1 = {"a": [{"c": "d"}, {"e": "f"}], "t": None} + dict2 = { + "a": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"], + "t": {"g": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"]}, + } + expected_merged = { + "a": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"], + "t": {"g": [{"1": "2"}, {"3": "4"}, {"5": "6"}, "7"]}, + } + self.assertDictEqual(deep_override_dict(dict1, dict2), expected_merged) + def test_deep_override_skip_keys(self): dict1 = {"a": 1, "b": {"x": 2, "y": 3}, "c": [4, 5]} dict2 = { @@ -1866,3 +1984,263 @@ def test_deep_override_skip_keys(self): expected_result = {"a": 1, "b": {"x": 20, "y": 3, "z": 30}, "c": [4, 5]} self.assertEqual(deep_override_dict(dict1, dict2, skip_keys=["c", "d"]), expected_result) + + +@pytest.mark.parametrize( + "instance, region, amazon_sagemaker_price_result, expected", + [ + ( + "ml.t4g.nano", + "us-west-2", + { + "PriceList": [ + { + "terms": { + "OnDemand": { + "3WK7G7WSYVS3K492.JRTCKXETXF": { + "priceDimensions": { + "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": { + "unit": "Hrs", + "endRange": "Inf", + "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour", + "appliesTo": [], + "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7", + 
"beginRange": "0", + "pricePerUnit": {"USD": "0.9000000000"}, + } + } + } + } + }, + } + ] + }, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.9"}, + ), + ( + "ml.t4g.nano", + "eu-central-1", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.008"}, + ), + ( + "ml.t4g.nano", + "af-south-1", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.008"}, + ), + ( + "ml.t4g.nano", + "ap-northeast-2", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.008"}, + ), + ], +) +@patch("boto3.client") +def test_get_instance_rate_per_hour( + mock_client, instance, region, amazon_sagemaker_price_result, expected +): + + mock_client.return_value.get_products.side_effect = ( + lambda *args, **kwargs: amazon_sagemaker_price_result + ) + instance_rate = get_instance_rate_per_hour(instance_type=instance, region=region) + + assert instance_rate == expected + + +@pytest.mark.parametrize( + "price_data, expected_result", + [ + (None, None), + ( + { + "terms": { + "OnDemand": { + "3WK7G7WSYVS3K492.JRTCKXETXF": { + "priceDimensions": { + "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": { + "unit": "Hrs", + "endRange": "Inf", + "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour", + "appliesTo": [], + "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7", + "beginRange": "0", + "pricePerUnit": {"USD": "0.9000000000"}, + } + } + } + } + } + }, + {"name": "On-demand Instance Rate", "unit": "USD/Hr", "value": "0.9"}, + ), + ], +) +def test_extract_instance_rate_per_hour(price_data, expected_result): + out = extract_instance_rate_per_hour(price_data) + + assert out == expected_result + + +@pytest.mark.parametrize( + "routing_config, expected", + [ + ({"RoutingStrategy": RoutingStrategy.RANDOM}, {"RoutingStrategy": "RANDOM"}), + ({"RoutingStrategy": "RANDOM"}, {"RoutingStrategy": "RANDOM"}), + ( + 
{"RoutingStrategy": RoutingStrategy.LEAST_OUTSTANDING_REQUESTS}, + {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"}, + ), + ( + {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"}, + {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"}, + ), + ({"RoutingStrategy": None}, None), + (None, None), + ], +) +def test_resolve_routing_config(routing_config, expected): + res = _resolve_routing_config(routing_config) + + assert res == expected + + +def test_resolve_routing_config_ex(): + pytest.raises(ValueError, lambda: _resolve_routing_config({"RoutingStrategy": "Invalid"})) + + +class TestConvertToPascalCase(TestCase): + def test_simple_dict(self): + input_dict = {"first_name": "John", "last_name": "Doe"} + expected_output = {"FirstName": "John", "LastName": "Doe"} + self.assertEqual(camel_case_to_pascal_case(input_dict), expected_output) + + def camel_case_to_pascal_case_nested(self): + input_dict = { + "model_name": "my-model", + "primary_container": { + "image": "my-docker-image:latest", + "model_data_url": "s3://my-bucket/model.tar.gz", + "environment": {"env_var_1": "value1", "env_var_2": "value2"}, + }, + "execution_role_arn": "arn:aws:iam::123456789012:role/my-sagemaker-role", + "tags": [ + {"key": "project", "value": "my-project"}, + {"key": "environment", "value": "development"}, + ], + } + expected_output = { + "ModelName": "my-model", + "PrimaryContainer": { + "Image": "my-docker-image:latest", + "ModelDataUrl": "s3://my-bucket/model.tar.gz", + "Environment": {"EnvVar1": "value1", "EnvVar2": "value2"}, + }, + "ExecutionRoleArn": "arn:aws:iam::123456789012:role/my-sagemaker-role", + "Tags": [ + {"Key": "project", "Value": "my-project"}, + {"Key": "environment", "Value": "development"}, + ], + } + self.assertEqual(camel_case_to_pascal_case(input_dict), expected_output) + + def test_empty_input(self): + self.assertEqual(camel_case_to_pascal_case({}), {}) + + +class TestTags(TestCase): + def test_tag_exists(self): + curr_tags = [{"Key": "project", "Value": "my-project"}] + self.assertTrue(tag_exists({"Key": "project", "Value": "my-project"}, curr_tags=curr_tags)) + + def test_does_not_tag_exists(self): + curr_tags = [{"Key": "project", "Value": "my-project"}] + self.assertFalse( + tag_exists({"Key": "project-2", "Value": "my-project-2"}, curr_tags=curr_tags) + ) + + def test_add_tags(self): + curr_tags = [{"Key": "project", "Value": "my-project"}] + new_tag = {"Key": "project-2", "Value": "my-project-2"} + expected = [ + {"Key": "project", "Value": "my-project"}, + {"Key": "project-2", "Value": "my-project-2"}, + ] + + self.assertEqual(_validate_new_tags(new_tag, curr_tags), expected) + + def test_new_add_tags(self): + new_tag = {"Key": "project-2", "Value": "my-project-2"} + + self.assertEqual(_validate_new_tags(new_tag, None), new_tag) + + def test_remove_existing_tag(self): + original_tags = [ + {"Key": "Tag1", "Value": "Value1"}, + {"Key": "Tag2", "Value": "Value2"}, + {"Key": "Tag3", "Value": "Value3"}, + ] + expected_output = [{"Key": "Tag1", "Value": "Value1"}, {"Key": "Tag3", "Value": "Value3"}] + self.assertEqual(remove_tag_with_key("Tag2", original_tags), expected_output) + + def test_remove_non_existent_tag(self): + original_tags = [ + {"Key": "Tag1", "Value": "Value1"}, + {"Key": "Tag2", "Value": "Value2"}, + {"Key": "Tag3", "Value": "Value3"}, + ] + self.assertEqual(remove_tag_with_key("NonExistentTag", original_tags), original_tags) + + def test_remove_only_tag(self): + original_tags = [{"Key": "Tag1", "Value": "Value1"}] + self.assertIsNone(remove_tag_with_key("Tag1", 
+
+
+class TestGetDomainForRegion(TestCase):
+    def test_get_domain_for_region(self):
+        self.assertEqual(get_domain_for_region("us-west-2"), "amazonaws.com")
+        self.assertEqual(get_domain_for_region("eu-west-1"), "amazonaws.com")
+        self.assertEqual(get_domain_for_region("ap-northeast-1"), "amazonaws.com")
+        self.assertEqual(get_domain_for_region("us-gov-west-1"), "amazonaws.com")
+        self.assertEqual(get_domain_for_region("cn-northwest-1"), "amazonaws.com.cn")
+        self.assertEqual(get_domain_for_region("us-iso-east-1"), "c2s.ic.gov")
+        self.assertEqual(get_domain_for_region("us-isob-east-1"), "sc2s.sgov.gov")
+        self.assertEqual(get_domain_for_region("invalid-region"), "amazonaws.com")
diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py
index 18eab98149..b694e63fe1 100644
--- a/tests/unit/test_xgboost.py
+++ b/tests/unit/test_xgboost.py
@@ -330,7 +330,7 @@ def test_xgboost_cpu(time, strftime, sagemaker_session, xgboost_framework_versio
     sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
     assert sagemaker_call_names == ["train", "logs_for_job"]
     boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
-    assert boto_call_names == ["resource"]
+    assert "resource" in boto_call_names
 
     expected_train_args = _create_train_job(xgboost_framework_version)
     expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
@@ -377,7 +377,7 @@ def test_xgboost_gpu(time, strftime, sagemaker_session, xgboost_gpu_framework_ve
     sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
     assert sagemaker_call_names == ["train", "logs_for_job"]
     boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
-    assert boto_call_names == ["resource"]
+    assert "resource" in boto_call_names
 
     expected_train_args = _create_train_job(
         xgboost_gpu_framework_version, instance_type=GPU_INSTANCE_TYPE
@@ -427,7 +427,7 @@ def test_distributed_training(time, strftime, sagemaker_session, xgboost_framewo
     sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
     assert sagemaker_call_names == ["train", "logs_for_job"]
     boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
-    assert boto_call_names == ["resource"]
+    assert "resource" in boto_call_names
 
     expected_train_args = _create_train_job(xgboost_framework_version, DIST_INSTANCE_COUNT)
     expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
diff --git a/tests/unit/tuner_visualize_test_utils.py b/tests/unit/tuner_visualize_test_utils.py
new file mode 100644
index 0000000000..d9524ff7e6
--- /dev/null
+++ b/tests/unit/tuner_visualize_test_utils.py
@@ -0,0 +1,141 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
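+# Shared fixture data for the HyperparameterTuner visualization (amtviz) unit tests.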
+from __future__ import absolute_import
+
+TRIALS_DF_COLUMNS = [
+    "criterion",
+    "max-depth",
+    "min-samples-leaf",
+    "min-weight-fraction-leaf",
+    "n-estimators",
+    "TrainingJobName",
+    "TrainingJobStatus",
+    "TrainingStartTime",
+    "TrainingEndTime",
+    "TrainingElapsedTimeSeconds",
+    "TuningJobName",
+    "valid-f1",
+]
+
+FULL_DF_COLUMNS = [
+    "value",
+    "ts",
+    "label",
+    "rel_ts",
+    "TrainingJobName",
+    "criterion",
+    "max-depth",
+    "min-samples-leaf",
+    "min-weight-fraction-leaf",
+    "n-estimators",
+    "TrainingJobStatus",
+    "TrainingStartTime",
+    "TrainingEndTime",
+    "TrainingElapsedTimeSeconds",
+    "TuningJobName",
+    "valid-f1",
+]
+
+
+TRIALS_DF_TRAINING_JOB_NAMES = [
+    "random-240712-1545-019-4ac17a84",
+    "random-240712-1545-021-fcd64dc1",
+]
+
+TRIALS_DF_TRAINING_JOB_STATUSES = ["Completed", "Completed"]
+
+TUNING_JOB_NAME_1 = "random-240712-1500"
+TUNING_JOB_NAME_2 = "bayesian-240712-1600"
+TUNING_JOB_NAMES = [TUNING_JOB_NAME_1, TUNING_JOB_NAME_2]
+TRIALS_DF_VALID_F1_VALUES = [0.950, 0.896]
+
+TUNED_PARAMETERS = [
+    "n-estimators",
+    "max-depth",
+    "min-samples-leaf",
+    "min-weight-fraction-leaf",
+    "criterion",
+]
+OBJECTIVE_NAME = "valid-f1"
+
+TRIALS_DF_DATA = {
+    "criterion": ["gini", "log_loss"],
+    "max-depth": [18.0, 8.0],
+    "min-samples-leaf": [3.0, 10.0],
+    "min-weight-fraction-leaf": [0.011596, 0.062067],
+    "n-estimators": [110.0, 18.0],
+    "TrainingJobName": ["random-240712-1545-019-4ac17a84", "random-240712-1545-021-fcd64dc1"],
+    "TrainingJobStatus": ["Completed", "Completed"],
+    "TrainingStartTime": ["2024-07-12 17:55:59+02:00", "2024-07-12 17:56:50+02:00"],
+    "TrainingEndTime": ["2024-07-12 17:56:43+02:00", "2024-07-12 17:57:29+02:00"],
+    "TrainingElapsedTimeSeconds": [44.0, 39.0],
+    "TuningJobName": TUNING_JOB_NAMES,
+    "valid-f1": [0.950, 0.896],
+}
+
+FULL_DF_DATA = {
+    "value": [0.951000, 0.950000],
+    "ts": ["2024-07-12 15:56:00", "2024-07-12 15:56:00"],
+    "label": ["valid-precision", "valid-recall"],
+    "rel_ts": ["1970-01-01 01:00:00", "1970-01-01 01:00:00"],
+    "TrainingJobName": ["random-240712-1545-019-4ac17a84", "random-240712-1545-019-4ac17a84"],
+    "criterion": ["gini", "gini"],
+    "max-depth": [18.0, 18.0],
+    "min-samples-leaf": [3.0, 3.0],
+    "min-weight-fraction-leaf": [0.011596, 0.011596],
+    "n-estimators": [110.0, 110.0],
+    "TrainingJobStatus": ["Completed", "Completed"],
+    "TrainingStartTime": ["2024-07-12 17:55:59+02:00", "2024-07-12 17:55:59+02:00"],
+    "TrainingEndTime": ["2024-07-12 17:56:43+02:00", "2024-07-12 17:56:43+02:00"],
+    "TrainingElapsedTimeSeconds": [44.0, 45.0],
+    "TuningJobName": ["random-240712-1545", "random-240712-1545"],
+    "valid-f1": [0.9500, 0.9500],
+}
+
+FILTERED_TUNING_JOB_DF_DATA = {
+    "criterion": ["log_loss", "gini"],
+    "max-depth": [10.0, 16.0],
+    "min-samples-leaf": [7.0, 2.0],
+    "min-weight-fraction-leaf": [0.160910, 0.069803],
+    "n-estimators": [67.0, 79.0],
+    "TrainingJobName": ["random-240712-1545-050-c0b5c10a", "random-240712-1545-049-2db2ec05"],
+    "TrainingJobStatus": ["Completed", "Completed"],
+    "FinalObjectiveValue": [0.8190, 0.8910],
+    "TrainingStartTime": ["2024-07-12 18:09:48+02:00", "2024-07-12 18:09:45+02:00"],
+    "TrainingEndTime": ["2024-07-12 18:10:28+02:00", "2024-07-12 18:10:23+02:00"],
"TrainingElapsedTimeSeconds": [40.0, 38.0], + "TuningJobName": [TUNING_JOB_NAME_1, TUNING_JOB_NAME_2], +} + +TUNING_RANGES = [ + {"Name": "n-estimators", "MinValue": "1", "MaxValue": "200", "ScalingType": "Auto"}, + {"Name": "max-depth", "MinValue": "1", "MaxValue": "20", "ScalingType": "Auto"}, + {"Name": "min-samples-leaf", "MinValue": "1", "MaxValue": "10", "ScalingType": "Auto"}, + { + "Name": "min-weight-fraction-leaf", + "MinValue": "0.01", + "MaxValue": "0.5", + "ScalingType": "Auto", + }, + {"Name": "criterion", "Values": ['"gini"', '"entropy"', '"log_loss"']}, +] + +TUNING_JOB_RESULT = { + "HyperParameterTuningJobName": TUNING_JOB_NAME_1, + "HyperParameterTuningJobConfig": { + "Strategy": "Random", + "HyperParameterTuningJobObjective": {"Type": "Maximize", "MetricName": "valid-f1"}, + }, + "HyperParameterTuningJobStatus": "Completed", +} diff --git a/tox.ini b/tox.ini index 718e968013..9c624b2052 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,8 @@ # and then run "tox" from this directory. [tox] -envlist = black-format,flake8,pylint,docstyle,sphinx,doc8,twine,py38,py39,py310,py311 +isolated_build = true +envlist = black-format,flake8,pylint,docstyle,sphinx,doc8,twine,py39,py310,py311,py312 skip_missing_interpreters = False @@ -20,13 +21,13 @@ exclude = tests/data/ venv/ env/ - tests/unit/test_tensorboard.py # excluding this file for time being + tests/unit/test_tensorboard.py max-complexity = 10 ignore = C901, - E203, # whitespace before ':': Black disagrees with and explicitly violates this. + E203, FI10, FI12, FI13, @@ -34,7 +35,7 @@ ignore = FI15, FI16, FI17, - FI18, # __future__ import "annotations" missing -> check only Python 3.7 compatible + FI18, FI50, FI51, FI52, @@ -66,7 +67,9 @@ markers = [testenv] setenv = PYTHONHASHSEED=42 -pip_version = pip==21.3 +pip_version = pip==24.3 +allowlist_externals = + aws passenv = AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY @@ -81,20 +84,30 @@ passenv = # Can be used to specify which tests to run, e.g.: tox -- -s commands = python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')" - pip install 'apache-airflow==2.9.0' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.9.0/constraints-3.8.txt" - pip install 'torch==2.0.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' - pip install 'torchvision==0.15.2+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' - pip install 'dill>=0.3.8' + + pip install 'apache-airflow==2.10.4' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.10.4/constraints-3.9.txt" + pip install 'torch==2.3.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' + pip install 'torchvision==0.18.1+cpu' -f 'https://download.pytorch.org/whl/torch_stable.html' + pip install 'dill>=0.3.9' + pip install 'altair>=5.3' # needed for amtviz + pip install -U "sagemaker-core" # needed to keep sagemaker-core up to date pytest {posargs} -deps = .[test] +deps = + .[test] + asyncio + nest_asyncio + pytest-asyncio depends = - {py38,py39,py310,p311}: clean + {py39,py310,py311,py312}: clean + +[testenv:py312] +basepython = python3.12 [testenv:runcoverage] -description = run unit tests with coverage +description = run unit tests with coverage commands = - pytest --cov=sagemaker --cov-append {posargs} + pytest --cov=sagemaker --cov-append {posargs} {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86 [testenv:flake8] @@ -104,6 +117,7 @@ deps = -r requirements/tox/flake8_requirements.txt commands = flake8 +basepython = python3.12 [testenv:pylint] 
 skipdist = true
@@ -111,7 +125,7 @@ skip_install = true
 deps =
     -r requirements/tox/pylint_requirements.txt
 commands =
-    python -m pylint --rcfile=.pylintrc -j 0 src/sagemaker
+    python -m pylint --rcfile=.pylintrc -j 0 src/sagemaker --fail-under=9.9
 
 [testenv:spelling]
 skipdist = true
@@ -127,18 +141,18 @@ skip_install = true
 deps =
     -r requirements/tox/twine_requirements.txt
 commands =
-    python setup.py sdist
+    python -m build --sdist
     twine check dist/*.tar.gz
 
 [testenv:sphinx]
-pip_version = pip==21.3
+pip_version = pip==24.3
 changedir = doc
 # pip install requirements.txt is separate as RTD does it in separate steps
 # having the requirements.txt installed in deps above results in Double Requirement exception
 # https://github.com/pypa/pip/issues/988
 commands =
     pip install --exists-action=w -r requirements.txt
-    sphinx-build -T -W -b html -d _build/doctrees-readthedocs -D language=en . _build/html
+    sphinx-build -T -b html -d _build/doctrees-readthedocs -D language=en . _build/html
 
 [testenv:doc8]
 deps =