Skip to content

Commit fa79a5d

Browse files
committed
feat: jumpstart retrieve functions (wip)
1 parent b09793a commit fa79a5d

File tree

13 files changed

+1011
-125
lines changed

13 files changed

+1011
-125
lines changed

src/sagemaker/image_uris.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from sagemaker import utils
2222
from sagemaker.spark import defaults
23+
from sagemaker.jumpstart import accessors as jumpstart_accessors
2324

2425
logger = logging.getLogger(__name__)
2526

@@ -39,6 +40,8 @@ def retrieve(
3940
distribution=None,
4041
base_framework_version=None,
4142
training_compiler_config=None,
43+
model_id=None,
44+
model_version=None,
4245
):
4346
"""Retrieves the ECR URI for the Docker image matching the given arguments.
4447
@@ -69,13 +72,56 @@ def retrieve(
6972
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7073
A configuration class for the SageMaker Training Compiler
7174
(default: None).
75+
model_id (str): JumpStart model id for which to retrieve image URI.
76+
model_version (str): JumpStart model version for which to retrieve image URI.
7277
7378
Returns:
7479
str: the ECR URI for the corresponding SageMaker Docker image.
7580
7681
Raises:
7782
ValueError: If the combination of arguments specified is not supported.
7883
"""
84+
if model_id is not None or model_version is not None:
85+
if model_id is None or model_version is None:
86+
raise ValueError(
87+
"Must specify `model_id` and `model_version` when getting image uri for "
88+
"JumpStart models. "
89+
)
90+
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
91+
region, model_id, model_version
92+
)
93+
if image_scope is None:
94+
raise ValueError(
95+
"Must specify `image_scope` argument to retrieve image uri for " "JumpStart models."
96+
)
97+
if image_scope == "inference":
98+
ecr_specs = model_specs.hosting_ecr_specs
99+
elif image_scope == "training":
100+
if not model_specs.training_supported:
101+
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
102+
ecr_specs = model_specs.training_ecr_specs
103+
else:
104+
raise ValueError("JumpStart models only support inference and training.")
105+
106+
if framework != None and framework != ecr_specs.framework:
107+
raise ValueError(
108+
f"Bad value for container framework for JumpStart model: '{framework}'."
109+
)
110+
111+
return retrieve(
112+
framework=ecr_specs.framework,
113+
region=region,
114+
version=ecr_specs.framework_version,
115+
py_version=ecr_specs.py_version,
116+
instance_type=instance_type,
117+
accelerator_type=accelerator_type,
118+
image_scope=image_scope,
119+
container_version=container_version,
120+
distribution=distribution,
121+
base_framework_version=base_framework_version,
122+
training_compiler_config=training_compiler_config,
123+
)
124+
79125
if training_compiler_config is None:
80126
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
81127
elif framework == HUGGING_FACE_FRAMEWORK:
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains accessors related to SageMaker JumpStart."""
14+
from __future__ import absolute_import
15+
from typing import Any, Dict, Optional
16+
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
17+
from sagemaker.jumpstart import cache
18+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
19+
20+
21+
class SageMakerSettings(object):
22+
"""Static class for storing the SageMaker settings."""
23+
24+
_PARSED_SAGEMAKER_VERSION = ""
25+
26+
@staticmethod
27+
def set_sagemaker_version(version: str) -> None:
28+
"""Set SageMaker version."""
29+
SageMakerSettings._PARSED_SAGEMAKER_VERSION = version
30+
31+
@staticmethod
32+
def get_sagemaker_version() -> str:
33+
"""Return SageMaker version."""
34+
return SageMakerSettings._PARSED_SAGEMAKER_VERSION
35+
36+
37+
class JumpStartModelsCache(object):
38+
"""Static class for storing the JumpStart models cache."""
39+
40+
_cache: Optional[cache.JumpStartModelsCache] = None
41+
_curr_region = JUMPSTART_DEFAULT_REGION_NAME
42+
43+
_cache_kwargs = {}
44+
45+
def _validate_region_cache_kwargs(
46+
cache_kwargs: Dict[str, Any] = {}, region: Optional[str] = None
47+
):
48+
if region is not None and "region" in cache_kwargs:
49+
if region != cache_kwargs["region"]:
50+
raise ValueError(
51+
f"Inconsistent region definitions: {region}, {cache_kwargs['region']}"
52+
)
53+
del cache_kwargs["region"]
54+
return cache_kwargs
55+
56+
@staticmethod
57+
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
58+
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(
59+
JumpStartModelsCache._cache_kwargs, region
60+
)
61+
if JumpStartModelsCache._cache == None or region != JumpStartModelsCache._curr_region:
62+
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
63+
JumpStartModelsCache._curr_region = region
64+
return JumpStartModelsCache._cache.get_header(model_id, version)
65+
66+
@staticmethod
67+
def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs:
68+
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(
69+
JumpStartModelsCache._cache_kwargs, region
70+
)
71+
if JumpStartModelsCache._cache == None or region != JumpStartModelsCache._curr_region:
72+
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
73+
JumpStartModelsCache._curr_region = region
74+
return JumpStartModelsCache._cache.get_specs(model_id, version)
75+
76+
@staticmethod
77+
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
78+
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(cache_kwargs, region)
79+
JumpStartModelsCache._cache_kwargs = cache_kwargs
80+
if region is None:
81+
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
82+
**JumpStartModelsCache._cache_kwargs
83+
)
84+
else:
85+
JumpStartModelsCache._curr_region = region
86+
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
87+
region=region, **JumpStartModelsCache._cache_kwargs
88+
)
89+
90+
@staticmethod
91+
def reset_cache(cache_kwargs: Dict[str, Any] = {}, region: str = None) -> None:
92+
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(cache_kwargs, region)
93+
JumpStartModelsCache._cache_kwargs = cache_kwargs
94+
if region is None:
95+
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
96+
**JumpStartModelsCache._cache_kwargs
97+
)
98+
else:
99+
JumpStartModelsCache._curr_region = region
100+
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
101+
region=region, **JumpStartModelsCache._cache_kwargs
102+
)

src/sagemaker/jumpstart/constants.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,98 @@
1717
from sagemaker.jumpstart.types import JumpStartLaunchedRegionInfo
1818

1919

20-
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set()
20+
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set(
21+
[
22+
JumpStartLaunchedRegionInfo(
23+
region_name="us-west-2",
24+
content_bucket="jumpstart-cache-prod-us-west-2",
25+
),
26+
JumpStartLaunchedRegionInfo(
27+
region_name="us-east-1",
28+
content_bucket="jumpstart-cache-prod-us-east-1",
29+
),
30+
JumpStartLaunchedRegionInfo(
31+
region_name="us-east-2",
32+
content_bucket="jumpstart-cache-prod-us-east-2",
33+
),
34+
JumpStartLaunchedRegionInfo(
35+
region_name="eu-west-1",
36+
content_bucket="jumpstart-cache-prod-eu-west-1",
37+
),
38+
JumpStartLaunchedRegionInfo(
39+
region_name="eu-central-1",
40+
content_bucket="jumpstart-cache-prod-eu-central-1",
41+
),
42+
JumpStartLaunchedRegionInfo(
43+
region_name="eu-north-1",
44+
content_bucket="jumpstart-cache-prod-eu-north-1",
45+
),
46+
JumpStartLaunchedRegionInfo(
47+
region_name="me-south-1",
48+
content_bucket="jumpstart-cache-prod-me-south-1",
49+
),
50+
JumpStartLaunchedRegionInfo(
51+
region_name="ap-south-1",
52+
content_bucket="jumpstart-cache-prod-ap-south-1",
53+
),
54+
JumpStartLaunchedRegionInfo(
55+
region_name="eu-west-3",
56+
content_bucket="jumpstart-cache-prod-eu-west-3",
57+
),
58+
JumpStartLaunchedRegionInfo(
59+
region_name="af-south-1",
60+
content_bucket="jumpstart-cache-prod-af-south-1",
61+
),
62+
JumpStartLaunchedRegionInfo(
63+
region_name="sa-east-1",
64+
content_bucket="jumpstart-cache-prod-sa-east-1",
65+
),
66+
JumpStartLaunchedRegionInfo(
67+
region_name="ap-east-1",
68+
content_bucket="jumpstart-cache-prod-ap-east-1",
69+
),
70+
JumpStartLaunchedRegionInfo(
71+
region_name="ap-northeast-2",
72+
content_bucket="jumpstart-cache-prod-ap-northeast-2",
73+
),
74+
JumpStartLaunchedRegionInfo(
75+
region_name="eu-west-2",
76+
content_bucket="jumpstart-cache-prod-eu-west-2",
77+
),
78+
JumpStartLaunchedRegionInfo(
79+
region_name="eu-south-1",
80+
content_bucket="jumpstart-cache-prod-eu-south-1",
81+
),
82+
JumpStartLaunchedRegionInfo(
83+
region_name="ap-northeast-1",
84+
content_bucket="jumpstart-cache-prod-ap-northeast-1",
85+
),
86+
JumpStartLaunchedRegionInfo(
87+
region_name="us-west-1",
88+
content_bucket="jumpstart-cache-prod-us-west-1",
89+
),
90+
JumpStartLaunchedRegionInfo(
91+
region_name="ap-southeast-1",
92+
content_bucket="jumpstart-cache-prod-ap-southeast-1",
93+
),
94+
JumpStartLaunchedRegionInfo(
95+
region_name="ap-southeast-2",
96+
content_bucket="jumpstart-cache-prod-ap-southeast-2",
97+
),
98+
JumpStartLaunchedRegionInfo(
99+
region_name="ca-central-1",
100+
content_bucket="jumpstart-cache-prod-ca-central-1",
101+
),
102+
JumpStartLaunchedRegionInfo(
103+
region_name="cn-north-1",
104+
content_bucket="jumpstart-cache-prod-cn-north-1",
105+
),
106+
JumpStartLaunchedRegionInfo(
107+
region_name="cn-northwest-1",
108+
content_bucket="jumpstart-cache-prod-cn-northwest-1",
109+
),
110+
]
111+
)
21112

22113
JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT = {
23114
region.region_name: region for region in JUMPSTART_LAUNCHED_REGIONS

src/sagemaker/jumpstart/utils.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,10 @@
1616
import semantic_version
1717
import sagemaker
1818
from sagemaker.jumpstart import constants
19+
from sagemaker.jumpstart import accessors
1920
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId
2021

2122

22-
class SageMakerSettings(object):
23-
"""Static class for storing the SageMaker settings."""
24-
25-
_PARSED_SAGEMAKER_VERSION = ""
26-
27-
@staticmethod
28-
def set_sagemaker_version(version: str) -> None:
29-
"""Set SageMaker version."""
30-
SageMakerSettings._PARSED_SAGEMAKER_VERSION = version
31-
32-
@staticmethod
33-
def get_sagemaker_version() -> str:
34-
"""Return SageMaker version."""
35-
return SageMakerSettings._PARSED_SAGEMAKER_VERSION
36-
37-
3823
def get_jumpstart_launched_regions_message() -> str:
3924
"""Returns formatted string indicating where JumpStart is launched."""
4025
if len(constants.JUMPSTART_REGION_NAME_SET) == 0:
@@ -95,9 +80,9 @@ def get_sagemaker_version() -> str:
9580
calls ``parse_sagemaker_version`` to retrieve the version and set
9681
the constant.
9782
"""
98-
if SageMakerSettings.get_sagemaker_version() == "":
99-
SageMakerSettings.set_sagemaker_version(parse_sagemaker_version())
100-
return SageMakerSettings.get_sagemaker_version()
83+
if accessors.SageMakerSettings.get_sagemaker_version() == "":
84+
accessors.SageMakerSettings.set_sagemaker_version(parse_sagemaker_version())
85+
return accessors.SageMakerSettings.get_sagemaker_version()
10186

10287

10388
def parse_sagemaker_version() -> str:

src/sagemaker/model_uris.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Functions for generating S3 model artifact URIs for pre-built SageMaker models."""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import logging
18+
import os
19+
import re
20+
from typing import Optional
21+
22+
from sagemaker.jumpstart import utils as jumpstart_utils
23+
from sagemaker.jumpstart import accessors as jumpstart_accessors
24+
from sagemaker.jumpstart import constants as jumpstart_constants
25+
26+
logger = logging.getLogger(__name__)
27+
28+
29+
def retrieve(
30+
region=jumpstart_constants.JUMPSTART_DEFAULT_REGION_NAME,
31+
model_id=None,
32+
model_version: Optional[str] = None,
33+
model_scope: Optional[str] = None,
34+
):
35+
"""Retrieves the model artifact URI for the model matching the given arguments.
36+
37+
Args:
38+
region (str): Region for which to retrieve model S3 URI.
39+
model_id (str): JumpStart model id for which to retrieve model S3 URI.
40+
model_version (str): JumpStart model version for which to retrieve model S3 URI.
41+
model_scope (str): The model type, i.e. what it is used for.
42+
Valid values: "training", "inference", "eia".
43+
Returns:
44+
str: the model artifact URI for the corresponding model.
45+
46+
Raises:
47+
ValueError: If the combination of arguments specified is not supported.
48+
"""
49+
if model_id is None or model_version is None:
50+
raise ValueError(
51+
"Must specify `model_id` and `model_version` when getting model artifact uri for "
52+
"JumpStart models. "
53+
)
54+
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
55+
region, model_id, model_version
56+
)
57+
if model_scope is None:
58+
raise ValueError(
59+
"Must specify `model_scope` argument to retrieve model artifact uri for JumpStart models."
60+
)
61+
if model_scope == "inference":
62+
model_artifact_key = model_specs.hosting_artifact_key
63+
elif model_scope == "training":
64+
if not model_specs.training_supported:
65+
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
66+
model_artifact_key = model_specs.training_artifact_key
67+
else:
68+
raise ValueError("JumpStart models only support inference and training.")
69+
70+
bucket = jumpstart_utils.get_jumpstart_content_bucket(region)
71+
72+
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
73+
74+
return model_s3_uri

0 commit comments

Comments
 (0)