Skip to content

Commit 840f3a1

Browse files
rsareddy0329 and Roja Reddy Sareddy authored
feat: Evaluator handshake to trainer (#5420)
* feat: Add support to trainer object for model parameter in Evaluator
* feat: Evaluator handshake with trainer

---------

Co-authored-by: Roja Reddy Sareddy <rsareddy@amazon.com>
1 parent 4055fcf commit 840f3a1

File tree

9 files changed: 125 additions (+125) and 45 deletions (-45).

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from dataclasses import dataclass
1414
from enum import Enum
1515
import re
16+
from sagemaker.train.base_trainer import BaseTrainer
17+
from sagemaker.core.utils.utils import Unassigned
1618

1719

1820
class _ModelType(Enum):
@@ -65,14 +67,14 @@ def __init__(self, sagemaker_session=None):
6567

6668
def resolve_model_info(
6769
self,
68-
base_model: Union[str, 'ModelPackage'],
70+
base_model: Union[str, BaseTrainer, 'ModelPackage'],
6971
hub_name: Optional[str] = None
7072
) -> _ModelInfo:
7173
"""
7274
Resolve model information from various input types.
7375
7476
Args:
75-
base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN
77+
base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN or BaseTrainer object with a completed job
7678
hub_name: Optional hub name for JumpStart models (defaults to SageMakerPublicHub)
7779
7880
Returns:
@@ -88,6 +90,17 @@ def resolve_model_info(
8890
return self._resolve_model_package_arn(base_model)
8991
else:
9092
return self._resolve_jumpstart_model(base_model, hub_name or self.DEFAULT_HUB_NAME)
93+
# Handle BaseTrainer type
94+
elif isinstance(base_model, BaseTrainer):
95+
if hasattr(base_model, '_latest_training_job') and hasattr(base_model._latest_training_job,
96+
'output_model_package_arn'):
97+
arn = base_model._latest_training_job.output_model_package_arn
98+
if not isinstance(arn, Unassigned):
99+
return self._resolve_model_package_arn(arn)
100+
else:
101+
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
102+
else:
103+
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
91104
else:
92105
# Not a string, so assume it's a ModelPackage object
93106
# Check if it has the expected attributes of a ModelPackage

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313

1414
from pydantic import BaseModel, validator
1515

16-
from sagemaker.core.resources import ModelPackageGroup
16+
from sagemaker.core.resources import ModelPackageGroup, ModelPackage
1717
from sagemaker.core.shapes import VpcConfig
1818

1919
if TYPE_CHECKING:
2020
from sagemaker.core.helper.session_helper import Session
2121

22+
from sagemaker.train.base_trainer import BaseTrainer
2223
# Module-level logger
2324
_logger = logging.getLogger(__name__)
2425

@@ -53,6 +54,7 @@ class BaseEvaluator(BaseModel):
5354
- JumpStart model ID (str): e.g., 'llama3-2-1b-instruct'
5455
- ModelPackage object: A fine-tuned model package
5556
- ModelPackage ARN (str): e.g., 'arn:aws:sagemaker:region:account:model-package/name/version'
57+
- BaseTrainer object: A completed training job (i.e., it must have _latest_training_job with output_model_package_arn populated)
5658
base_eval_name (Optional[str]): Optional base name for evaluation jobs. This name is used
5759
as the PipelineExecutionDisplayName when creating the SageMaker pipeline execution.
5860
The actual display name will be "{base_eval_name}-{timestamp}". This parameter can
@@ -86,7 +88,7 @@ class BaseEvaluator(BaseModel):
8688

8789
region: Optional[str] = None
8890
sagemaker_session: Optional[Any] = None
89-
model: Union[str, Any]
91+
model: Union[str, BaseTrainer, ModelPackage]
9092
base_eval_name: Optional[str] = None
9193
s3_output_path: str
9294
mlflow_resource_arn: Optional[str] = None
@@ -278,7 +280,7 @@ def _validate_mlflow_arn_format(cls, v: Optional[str]) -> Optional[str]:
278280
return v
279281

280282
@validator('model')
281-
def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any]:
283+
def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: dict) -> Union[str, Any]:
282284
"""Resolve model information from various input types.
283285
284286
This validator uses the common model resolution utility to extract:
@@ -289,7 +291,7 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any
289291
The resolved information is stored in private attributes for use by subclasses.
290292
291293
Args:
292-
v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, or ARN).
294+
v (Union[str, BaseTrainer, ModelPackage]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer).
293295
values (dict): Dictionary of already-validated fields.
294296
295297
Returns:

sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -300,18 +300,10 @@ class BenchMarkEvaluator(BaseEvaluator):
300300
"""
301301

302302
benchmark: _Benchmark
303-
dataset: Union[str, Any] # Required field, must come before optional fields
304303
subtasks: Optional[Union[str, List[str]]] = None
305304
evaluate_base_model: bool = True
306305
_hyperparameters: Optional[Any] = None
307-
308-
@validator('dataset', pre=True)
309-
def _resolve_dataset(cls, v):
310-
"""Resolve dataset to string (S3 URI or ARN) and validate format.
311-
312-
Uses BaseEvaluator's common validation logic to avoid code duplication.
313-
"""
314-
return BaseEvaluator._validate_and_resolve_dataset(v)
306+
315307

316308
@validator('benchmark')
317309
def _validate_benchmark_model_compatibility(cls, v, values):

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
286286
except TimeoutExceededError as e:
287287
logger.error("Error: %s", e)
288288

289-
self.latest_training_job = training_job
289+
self._latest_training_job = training_job
290290
return training_job
291291

292292
def _process_hyperparameters(self):

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,5 +274,5 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
274274
except TimeoutExceededError as e:
275275
logger.error("Error: %s", e)
276276

277-
self.latest_training_job = training_job
277+
self._latest_training_job = training_job
278278
return training_job

sagemaker-train/src/sagemaker/train/sft_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
268268
except TimeoutExceededError as e:
269269
logger.error("Error: %s", e)
270270

271-
self.latest_training_job = training_job
271+
self._latest_training_job = training_job
272272
return training_job
273273

274274

sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
_ModelResolver,
2525
_resolve_base_model,
2626
)
27+
from sagemaker.train.base_trainer import BaseTrainer
28+
from sagemaker.core.utils.utils import Unassigned
2729

2830

2931
class TestModelType:
@@ -557,3 +559,74 @@ def test_resolve_base_model_with_hub_name(self, mock_resolver_class):
557559
_resolve_base_model("test-model", hub_name="CustomHub")
558560

559561
mock_resolver.resolve_model_info.assert_called_once_with("test-model", "CustomHub")
562+
563+
564+
class TestBaseTrainerHandling:
565+
"""Tests for BaseTrainer model handling in _resolve_base_model."""
566+
567+
def test_base_trainer_with_valid_training_job(self):
568+
"""Test BaseTrainer with valid completed training job."""
569+
# Create concrete BaseTrainer subclass for testing
570+
class TestTrainer(BaseTrainer):
571+
def train(self, input_data_config, wait=True, logs=True):
572+
pass
573+
574+
mock_trainer = TestTrainer()
575+
mock_training_job = MagicMock()
576+
mock_training_job.output_model_package_arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-package/1"
577+
mock_trainer._latest_training_job = mock_training_job
578+
579+
with patch('sagemaker.train.common_utils.model_resolution._ModelResolver._resolve_model_package_arn') as mock_resolve_arn:
580+
mock_resolve_arn.return_value = MagicMock()
581+
582+
result = _resolve_base_model(mock_trainer)
583+
584+
# Verify model package ARN resolution was called
585+
mock_resolve_arn.assert_called_once_with(
586+
"arn:aws:sagemaker:us-west-2:123456789012:model-package/my-package/1"
587+
)
588+
589+
def test_base_trainer_with_unassigned_arn(self):
590+
"""Test BaseTrainer with Unassigned output_model_package_arn raises error."""
591+
# Create concrete BaseTrainer subclass for testing
592+
class TestTrainer(BaseTrainer):
593+
def train(self, input_data_config, wait=True, logs=True):
594+
pass
595+
596+
mock_trainer = TestTrainer()
597+
mock_training_job = MagicMock()
598+
mock_training_job.output_model_package_arn = Unassigned()
599+
mock_trainer._latest_training_job = mock_training_job
600+
601+
with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
602+
_resolve_base_model(mock_trainer)
603+
604+
def test_base_trainer_without_training_job(self):
605+
"""Test BaseTrainer without _latest_training_job raises error."""
606+
# Create concrete BaseTrainer subclass for testing
607+
class TestTrainer(BaseTrainer):
608+
def train(self, input_data_config, wait=True, logs=True):
609+
pass
610+
611+
mock_trainer = TestTrainer()
612+
# Don't set _latest_training_job attribute at all
613+
614+
with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
615+
_resolve_base_model(mock_trainer)
616+
617+
def test_base_trainer_without_output_model_package_arn_attribute(self):
618+
"""Test BaseTrainer with training job but missing output_model_package_arn attribute."""
619+
# Create concrete BaseTrainer subclass for testing
620+
class TestTrainer(BaseTrainer):
621+
def train(self, input_data_config, wait=True, logs=True):
622+
pass
623+
624+
# Create a simple object without output_model_package_arn
625+
class TrainingJobWithoutArn:
626+
pass
627+
628+
mock_trainer = TestTrainer()
629+
mock_trainer._latest_training_job = TrainingJobWithoutArn()
630+
631+
with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
632+
_resolve_base_model(mock_trainer)

sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from sagemaker.core.shapes import VpcConfig
2121
from sagemaker.core.resources import ModelPackageGroup, Artifact
2222
from sagemaker.core.shapes import ArtifactSource, ArtifactSourceType
23+
from sagemaker.core.utils.utils import Unassigned
24+
from sagemaker.train.base_trainer import BaseTrainer
2325

2426
from sagemaker.train.evaluate.base_evaluator import BaseEvaluator
2527

0 commit comments

Comments
 (0)