Skip to content

Commit 0aed8fa

Browse files
committed
add evaluator tagging for jumpstart models
1 parent 194e74b commit 0aed8fa

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99

1010
import logging
1111
import re
12-
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
12+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
1313

1414
from pydantic import BaseModel, validator
1515

16+
from sagemaker.core.common_utils import TagsDict
1617
from sagemaker.core.resources import ModelPackageGroup
1718
from sagemaker.core.shapes import VpcConfig
1819

@@ -411,6 +412,13 @@ def _source_model_package_arn(self) -> Optional[str]:
411412
"""Get the resolved source model package ARN (None for JumpStart models)."""
412413
info = self._get_resolved_model_info()
413414
return info.source_model_package_arn if info else None
415+
416+
@property
417+
def _is_jumpstart_model(self) -> bool:
418+
"""Determine if model is a JumpStart model"""
419+
from sagemaker.train.common_utils.model_resolution import _ModelType
420+
info = self._get_resolved_model_info()
421+
return info.model_type == _ModelType.JUMPSTART
414422

415423
def _infer_model_package_group_arn(self) -> Optional[str]:
416424
"""Infer model package group ARN from source model package ARN.
@@ -795,6 +803,12 @@ def _start_execution(
795803
EvaluationPipelineExecution: Started execution object
796804
"""
797805
from .execution import EvaluationPipelineExecution
806+
807+
tags: List[TagsDict] = []
808+
809+
if self._is_jumpstart_model:
810+
from sagemaker.core.jumpstart.utils import add_jumpstart_model_info_tags
811+
tags = add_jumpstart_model_info_tags(tags, self.model, "*")
798812

799813
execution = EvaluationPipelineExecution.start(
800814
eval_type=eval_type,
@@ -803,7 +817,8 @@ def _start_execution(
803817
role_arn=role_arn,
804818
s3_output_path=self.s3_output_path,
805819
session=self.sagemaker_session.boto_session if hasattr(self.sagemaker_session, 'boto_session') else None,
806-
region=region
820+
region=region,
821+
tags=tags
807822
)
808823

809824
return execution

sagemaker-train/src/sagemaker/train/evaluate/execution.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# Third-party imports
1717
from botocore.exceptions import ClientError
1818
from pydantic import BaseModel, Field
19+
from sagemaker.core.common_utils import TagsDict
1920
from sagemaker.core.helper.session_helper import Session
2021
from sagemaker.core.resources import Pipeline, PipelineExecution, Tag
2122
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
@@ -38,6 +39,7 @@ def _create_evaluation_pipeline(
3839
pipeline_definition: str,
3940
session: Optional[Any] = None,
4041
region: Optional[str] = None,
42+
tags: Optional[List[TagsDict]] = [],
4143
) -> Any:
4244
"""Helper method to create a SageMaker pipeline for evaluation.
4345
@@ -49,6 +51,7 @@ def _create_evaluation_pipeline(
4951
pipeline_definition (str): JSON pipeline definition (Jinja2 template).
5052
session (Optional[Any]): SageMaker session object.
5153
region (Optional[str]): AWS region.
54+
tags (Optional[List[TagsDict]]): List of tags to include in pipeline
5255
5356
Returns:
5457
Any: Created Pipeline instance (ready for execution).
@@ -65,9 +68,9 @@ def _create_evaluation_pipeline(
6568
resolved_pipeline_definition = template.render(pipeline_name=pipeline_name)
6669

6770
# Create tags for the pipeline
68-
tags = [
71+
tags = tags.extend([
6972
{"key": _TAG_SAGEMAKER_MODEL_EVALUATION, "value": "true"}
70-
]
73+
])
7174

7275
pipeline = Pipeline.create(
7376
pipeline_name=pipeline_name,
@@ -163,7 +166,8 @@ def _get_or_create_pipeline(
163166
pipeline_definition: str,
164167
role_arn: str,
165168
session: Optional[Session] = None,
166-
region: Optional[str] = None
169+
region: Optional[str] = None,
170+
tags: Optional[List[TagsDict]] = [],
167171
) -> Pipeline:
168172
"""Get existing pipeline or create/update it.
169173
@@ -177,6 +181,7 @@ def _get_or_create_pipeline(
177181
role_arn: IAM role ARN for pipeline execution
178182
session: Boto3 session (optional)
179183
region: AWS region (optional)
184+
tags (Optional[List[TagsDict]]): List of tags to include in pipeline
180185
181186
Returns:
182187
Pipeline instance (existing updated or newly created)
@@ -202,7 +207,7 @@ def _get_or_create_pipeline(
202207

203208
# Get tags using Tag.get_all
204209
tags_list = Tag.get_all(resource_arn=pipeline_arn, session=session, region=region)
205-
tags = {tag.key: tag.value for tag in tags_list}
210+
tags = tags.extend({tag.key: tag.value for tag in tags_list})
206211

207212
# Validate tag
208213
if tags.get(_TAG_SAGEMAKER_MODEL_EVALUATION) == "true":
@@ -505,7 +510,8 @@ def start(
505510
role_arn: str,
506511
s3_output_path: Optional[str] = None,
507512
session: Optional[Session] = None,
508-
region: Optional[str] = None
513+
region: Optional[str] = None,
514+
tags: Optional[List[TagsDict]] = [],
509515
) -> 'EvaluationPipelineExecution':
510516
"""Create sagemaker pipeline execution. Optionally creates pipeline.
511517
@@ -517,6 +523,7 @@ def start(
517523
s3_output_path (Optional[str]): S3 location where evaluation results are stored.
518524
session (Optional[Session]): Boto3 session for API calls.
519525
region (Optional[str]): AWS region for the pipeline.
526+
tags (Optional[List[TagsDict]]): List of tags to include in pipeline
520527
521528
Returns:
522529
EvaluationPipelineExecution: Started pipeline execution instance.
@@ -547,7 +554,8 @@ def start(
547554
pipeline_definition=pipeline_definition,
548555
role_arn=role_arn,
549556
session=session,
550-
region=region
557+
region=region,
558+
tags=tags,
551559
)
552560

553561
# Start pipeline execution via boto3

0 commit comments

Comments
 (0)