Skip to content

Commit 84bf597

Browse files
committed
change: use packaging library for jumpstart versions
1 parent fa79a5d commit 84bf597

File tree

5 files changed

+45
-51
lines changed

5 files changed

+45
-51
lines changed

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def read_version():
4444
"packaging>=20.0",
4545
"pandas",
4646
"pathos",
47-
"semantic-version",
4847
]
4948

5049
# Specific use case dependencies

src/sagemaker/jumpstart/cache.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
import json
1818
import boto3
1919
import botocore
20-
import semantic_version
20+
from packaging.version import Version
21+
from packaging.specifiers import SpecifierSet
2122
from sagemaker.jumpstart.constants import (
2223
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
2324
JUMPSTART_DEFAULT_REGION_NAME,
@@ -146,7 +147,7 @@ def _get_manifest_key_from_model_id_semantic_version(
146147
) -> JumpStartVersionedModelId:
147148
"""Return model id and version in manifest that matches semantic version/id.
148149
149-
Uses ``semantic_version`` to perform version comparison. The highest model version
150+
Uses ``packaging.version`` to perform version comparison. The highest model version
150151
matching the semantic version is used, which is compatible with the SageMaker
151152
version.
152153
@@ -169,30 +170,27 @@ def _get_manifest_key_from_model_id_semantic_version(
169170
sm_version = utils.get_sagemaker_version()
170171

171172
versions_compatible_with_sagemaker = [
172-
semantic_version.Version(header.version)
173+
Version(header.version)
173174
for header in manifest.values()
174-
if header.model_id == model_id
175-
and semantic_version.Version(header.min_version) <= semantic_version.Version(sm_version)
175+
if header.model_id == model_id and Version(header.min_version) <= Version(sm_version)
176176
]
177177

178-
spec = (
179-
semantic_version.SimpleSpec("*")
180-
if version is None
181-
else semantic_version.SimpleSpec(version)
178+
sm_compatible_model_version = self._select_version(
179+
version, versions_compatible_with_sagemaker
182180
)
183181

184-
sm_compatible_model_version = spec.select(versions_compatible_with_sagemaker)
185182
if sm_compatible_model_version is not None:
186-
return JumpStartVersionedModelId(model_id, str(sm_compatible_model_version))
183+
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)
187184

188185
versions_incompatible_with_sagemaker = [
189-
semantic_version.Version(header.version)
190-
for header in manifest.values()
191-
if header.model_id == model_id
186+
Version(header.version) for header in manifest.values() if header.model_id == model_id
192187
]
193-
sm_incompatible_model_version = spec.select(versions_incompatible_with_sagemaker)
188+
sm_incompatible_model_version = self._select_version(
189+
version, versions_incompatible_with_sagemaker
190+
)
191+
194192
if sm_incompatible_model_version is not None:
195-
model_version_to_use_incompatible_with_sagemaker = str(sm_incompatible_model_version)
193+
model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version
196194
sm_version_to_use = [
197195
header.min_version
198196
for header in manifest.values()
@@ -275,6 +273,29 @@ def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModel
275273

276274
return self._get_header_impl(model_id, semantic_version_str=semantic_version_str)
277275

276+
def _select_version(
277+
self,
278+
semantic_version_str: str,
279+
available_versions: List[Version],
280+
) -> Optional[Version]:
281+
"""Utility to select appropriate version from available version given
282+
a semantic version with which to filter.
283+
284+
Args:
285+
semantic_version_str (str): the semantic version for which to filter
286+
available versions.
287+
available_versions (List[Version]): list of available versions.
288+
"""
289+
if semantic_version_str == "*":
290+
if len(available_versions) is 0:
291+
return None
292+
else:
293+
return str(max(available_versions))
294+
else:
295+
spec = SpecifierSet(f"=={semantic_version_str}")
296+
available_versions = list(spec.filter(available_versions))
297+
return str(available_versions[0]) if available_versions != [] else None
298+
278299
def _get_header_impl(
279300
self,
280301
model_id: str,

src/sagemaker/jumpstart/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""This module contains utilities related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
from typing import Dict, List
16-
import semantic_version
16+
from packaging.version import Version
1717
import sagemaker
1818
from sagemaker.jumpstart import constants
1919
from sagemaker.jumpstart import accessors
@@ -89,14 +89,14 @@ def parse_sagemaker_version() -> str:
8989
"""Returns sagemaker library version. This should only be called once.
9090
9191
Function reads ``__version__`` variable in ``sagemaker`` module.
92-
In order to maintain compatibility with the ``semantic_version``
92+
In order to maintain compatibility with the ``packaging.version``
9393
library, versions with fewer than 2, or more than 3, periods are rejected.
94-
All versions that cannot be parsed with ``semantic_version`` are also
94+
All versions that cannot be parsed with ``packaging.version`` are also
9595
rejected.
9696
9797
Raises:
9898
RuntimeError: If the SageMaker version is not readable. An exception is also raised if
99-
the version cannot be parsed by ``semantic_version``.
99+
the version cannot be parsed by ``packaging.version``.
100100
"""
101101
version = sagemaker.__version__
102102
parsed_version = None
@@ -110,6 +110,6 @@ def parse_sagemaker_version() -> str:
110110
else:
111111
raise RuntimeError(f"Bad value for SageMaker version: {sagemaker.__version__}")
112112

113-
semantic_version.Version(parsed_version)
113+
Version(parsed_version)
114114

115115
return parsed_version

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,6 @@ def test_jumpstart_cache_get_header():
7878
model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="2.*"
7979
)
8080

81-
assert JumpStartModelHeader(
82-
{
83-
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
84-
"version": "2.0.0",
85-
"min_version": "2.49.0",
86-
"spec_key": "community_models_specs/tensorflow-ic-"
87-
"imagenet-inception-v3-classification-4/specs_v2.0.0.json",
88-
}
89-
) == cache.get_header(
90-
model_id="tensorflow-ic-imagenet-inception-v3-classification-4",
91-
semantic_version_str="2.*.*",
92-
)
93-
9481
assert JumpStartModelHeader(
9582
{
9683
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
@@ -129,19 +116,6 @@ def test_jumpstart_cache_get_header():
129116
model_id="tensorflow-ic-imagenet-inception-v3-classification-4", semantic_version_str="1.*"
130117
)
131118

132-
assert JumpStartModelHeader(
133-
{
134-
"model_id": "tensorflow-ic-imagenet-inception-v3-classification-4",
135-
"version": "1.0.0",
136-
"min_version": "2.49.0",
137-
"spec_key": "community_models_specs/tensorflow-ic-"
138-
"imagenet-inception-v3-classification-4/specs_v1.0.0.json",
139-
}
140-
) == cache.get_header(
141-
model_id="tensorflow-ic-imagenet-inception-v3-classification-4",
142-
semantic_version_str="1.*.*",
143-
)
144-
145119
with pytest.raises(KeyError) as e:
146120
cache.get_header(
147121
model_id="tensorflow-ic-imagenet-inception-v3-classification-4",
@@ -160,7 +134,7 @@ def test_jumpstart_cache_get_header():
160134
)
161135
assert "Consider upgrading" not in str(e.value)
162136

163-
with pytest.raises(ValueError):
137+
with pytest.raises(KeyError):
164138
cache.get_header(
165139
model_id="tensorflow-ic-imagenet-inception-v3-classification-4",
166140
semantic_version_str="BAD",
@@ -615,5 +589,5 @@ def test_jumpstart_cache_get_specs():
615589
with pytest.raises(KeyError):
616590
cache.get_specs(model_id=model_id, semantic_version_str="9.*")
617591

618-
with pytest.raises(ValueError):
592+
with pytest.raises(KeyError):
619593
cache.get_specs(model_id=model_id, semantic_version_str="BAD")

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ passenv =
6969
commands =
7070
python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')"
7171
pytest --cov=sagemaker --cov-append {posargs}
72-
{env:IGNORE_COVERAGE:} coverage report -i --fail-under=86
72+
{env:IGNORE_COVERAGE:} coverage report -i
7373
deps = .[test]
7474
depends =
7575
{py36,py37,py38}: clean

0 commit comments

Comments
 (0)