1717import json
1818import boto3
1919import botocore
20- import semantic_version
20+ from packaging .version import Version
21+ from packaging .specifiers import SpecifierSet
2122from sagemaker .jumpstart .constants import (
2223 JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
2324 JUMPSTART_DEFAULT_REGION_NAME ,
@@ -146,7 +147,7 @@ def _get_manifest_key_from_model_id_semantic_version(
146147 ) -> JumpStartVersionedModelId :
147148 """Return model id and version in manifest that matches semantic version/id.
148149
149- Uses ``semantic_version `` to perform version comparison. The highest model version
150+ Uses ``packaging.version `` to perform version comparison. The highest model version
150151 matching the semantic version is used, which is compatible with the SageMaker
151152 version.
152153
@@ -169,30 +170,27 @@ def _get_manifest_key_from_model_id_semantic_version(
169170 sm_version = utils .get_sagemaker_version ()
170171
171172 versions_compatible_with_sagemaker = [
172- semantic_version . Version (header .version )
173+ Version (header .version )
173174 for header in manifest .values ()
174- if header .model_id == model_id
175- and semantic_version .Version (header .min_version ) <= semantic_version .Version (sm_version )
175+ if header .model_id == model_id and Version (header .min_version ) <= Version (sm_version )
176176 ]
177177
178- spec = (
179- semantic_version .SimpleSpec ("*" )
180- if version is None
181- else semantic_version .SimpleSpec (version )
178+ sm_compatible_model_version = self ._select_version (
179+ version , versions_compatible_with_sagemaker
182180 )
183181
184- sm_compatible_model_version = spec .select (versions_compatible_with_sagemaker )
185182 if sm_compatible_model_version is not None :
186- return JumpStartVersionedModelId (model_id , str ( sm_compatible_model_version ) )
183+ return JumpStartVersionedModelId (model_id , sm_compatible_model_version )
187184
188185 versions_incompatible_with_sagemaker = [
189- semantic_version .Version (header .version )
190- for header in manifest .values ()
191- if header .model_id == model_id
186+ Version (header .version ) for header in manifest .values () if header .model_id == model_id
192187 ]
193- sm_incompatible_model_version = spec .select (versions_incompatible_with_sagemaker )
188+ sm_incompatible_model_version = self ._select_version (
189+ version , versions_incompatible_with_sagemaker
190+ )
191+
194192 if sm_incompatible_model_version is not None :
195- model_version_to_use_incompatible_with_sagemaker = str ( sm_incompatible_model_version )
193+ model_version_to_use_incompatible_with_sagemaker = sm_incompatible_model_version
196194 sm_version_to_use = [
197195 header .min_version
198196 for header in manifest .values ()
@@ -275,6 +273,29 @@ def get_header(self, model_id: str, semantic_version_str: str) -> JumpStartModel
275273
276274 return self ._get_header_impl (model_id , semantic_version_str = semantic_version_str )
277275
276+ def _select_version (
277+ self ,
278+ semantic_version_str : str ,
279+ available_versions : List [Version ],
280+ ) -> Optional [Version ]:
281+ """Utility to select appropriate version from available version given
282+ a semantic version with which to filter.
283+
284+ Args:
285+ semantic_version_str (str): the semantic version for which to filter
286+ available versions.
287+ available_versions (List[Version]): list of available versions.
288+ """
289+ if semantic_version_str == "*" :
290+ if len (available_versions ) is 0 :
291+ return None
292+ else :
293+ return str (max (available_versions ))
294+ else :
295+ spec = SpecifierSet (f"=={ semantic_version_str } " )
296+ available_versions = list (spec .filter (available_versions ))
297+ return str (available_versions [0 ]) if available_versions != [] else None
298+
278299 def _get_header_impl (
279300 self ,
280301 model_id : str ,
0 commit comments