Skip to content

Commit 9bb96bf

Browse files
author
Roja Reddy Sareddy
committed
feat: Add support to trainer object for model parameter in Evaluator
1 parent fb0d789 commit 9bb96bf

File tree

4 files changed

+128
-38
lines changed

4 files changed

+128
-38
lines changed

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515

1616
from sagemaker.core.resources import ModelPackageGroup
1717
from sagemaker.core.shapes import VpcConfig
18+
from sagemaker.core.utils.utils import Unassigned
1819

1920
if TYPE_CHECKING:
2021
from sagemaker.core.helper.session_helper import Session
22+
from sagemaker.train.base_trainer import BaseTrainer
2123

2224
# Module-level logger
2325
_logger = logging.getLogger(__name__)
@@ -278,7 +280,7 @@ def _validate_mlflow_arn_format(cls, v: Optional[str]) -> Optional[str]:
278280
return v
279281

280282
@validator('model')
281-
def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any]:
283+
def _resolve_model_info(cls, v: Union[str, "BaseTrainer", Any], values: dict) -> Union[str, Any]:
282284
"""Resolve model information from various input types.
283285
284286
This validator uses the common model resolution utility to extract:
@@ -289,7 +291,7 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any
289291
The resolved information is stored in private attributes for use by subclasses.
290292
291293
Args:
292-
v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, or ARN).
294+
v (Union[str, "BaseTrainer", Any]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer).
293295
values (dict): Dictionary of already-validated fields.
294296
295297
Returns:
@@ -302,12 +304,25 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any
302304
import os
303305

304306
try:
307+
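# Handle BaseTrainer input (duck-typed): matches when the class is named exactly 'BaseTrainer', or when the object exposes a `_latest_training_job` attribute. Note `and` binds tighter than `or` in the condition below.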
308+
if hasattr(v, '__class__') and v.__class__.__name__ == 'BaseTrainer' or hasattr(v, '_latest_training_job'):
309+
if hasattr(v._latest_training_job, 'output_model_package_arn'):
310+
arn = v._latest_training_job.output_model_package_arn
311+
if not isinstance(arn, Unassigned):
312+
model_to_resolve = arn
313+
else:
314+
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
315+
else:
316+
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
317+
else:
318+
model_to_resolve = v
319+
305320
# Get the session for resolution (may not be created yet due to validator order)
306321
session = values.get('sagemaker_session')
307322

308323
# Resolve model information
309324
model_info = _resolve_base_model(
310-
base_model=v,
325+
base_model=model_to_resolve,
311326
sagemaker_session=session
312327
)
313328

sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -300,18 +300,10 @@ class BenchMarkEvaluator(BaseEvaluator):
300300
"""
301301

302302
benchmark: _Benchmark
303-
dataset: Union[str, Any] # Required field, must come before optional fields
304303
subtasks: Optional[Union[str, List[str]]] = None
305304
evaluate_base_model: bool = True
306305
_hyperparameters: Optional[Any] = None
307306

308-
@validator('dataset', pre=True)
309-
def _resolve_dataset(cls, v):
310-
"""Resolve dataset to string (S3 URI or ARN) and validate format.
311-
312-
Uses BaseEvaluator's common validation logic to avoid code duplication.
313-
"""
314-
return BaseEvaluator._validate_and_resolve_dataset(v)
315307

316308
@validator('benchmark')
317309
def _validate_benchmark_model_compatibility(cls, v, values):

sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from sagemaker.core.shapes import VpcConfig
2121
from sagemaker.core.resources import ModelPackageGroup, Artifact
2222
from sagemaker.core.shapes import ArtifactSource, ArtifactSourceType
23+
from sagemaker.core.utils.utils import Unassigned
24+
from sagemaker.train.base_trainer import BaseTrainer
2325

2426
from sagemaker.train.evaluate.base_evaluator import BaseEvaluator
2527

@@ -1291,3 +1293,86 @@ def test_with_all_optional_params(self, mock_resolve, mock_session, mock_model_i
12911293
assert evaluator.networking == vpc_config
12921294
assert evaluator.kms_key_id == "arn:aws:kms:us-west-2:123456789012:key/12345"
12931295
assert evaluator.region == DEFAULT_REGION
1296+
1297+
1298+
class TestBaseTrainerHandling:
1299+
"""Tests for BaseTrainer model handling."""
1300+
1301+
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
1302+
def test_base_trainer_with_valid_training_job(self, mock_resolve, mock_session, mock_model_info_with_package):
1303+
"""Test BaseTrainer with valid completed training job."""
1304+
mock_resolve.return_value = mock_model_info_with_package
1305+
1306+
# Create mock BaseTrainer with completed training job
1307+
mock_trainer = MagicMock(spec=BaseTrainer)
1308+
mock_training_job = MagicMock()
1309+
mock_training_job.output_model_package_arn = DEFAULT_MODEL_PACKAGE_ARN
1310+
mock_trainer._latest_training_job = mock_training_job
1311+
1312+
evaluator = BaseEvaluator(
1313+
model=mock_trainer,
1314+
s3_output_path=DEFAULT_S3_OUTPUT,
1315+
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1316+
sagemaker_session=mock_session,
1317+
)
1318+
1319+
# Verify model resolution was called with the training job's model package ARN
1320+
mock_resolve.assert_called_once_with(
1321+
base_model=DEFAULT_MODEL_PACKAGE_ARN,
1322+
sagemaker_session=mock_session
1323+
)
1324+
assert evaluator.model == mock_trainer
1325+
1326+
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
1327+
def test_base_trainer_with_unassigned_arn(self, mock_resolve, mock_session):
1328+
"""Test BaseTrainer with Unassigned output_model_package_arn raises error."""
1329+
# Create mock BaseTrainer with Unassigned ARN
1330+
mock_trainer = MagicMock(spec=BaseTrainer)
1331+
mock_training_job = MagicMock()
1332+
mock_training_job.output_model_package_arn = Unassigned()
1333+
mock_trainer._latest_training_job = mock_training_job
1334+
1335+
with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"):
1336+
BaseEvaluator(
1337+
model=mock_trainer,
1338+
s3_output_path=DEFAULT_S3_OUTPUT,
1339+
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1340+
sagemaker_session=mock_session,
1341+
)
1342+
1343+
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
1344+
def test_base_trainer_without_training_job(self, mock_resolve, mock_session):
1345+
"""Test BaseTrainer without _latest_training_job falls through to normal processing."""
1346+
# Create mock BaseTrainer without _latest_training_job attribute
1347+
mock_trainer = MagicMock()
1348+
mock_trainer.__class__.__name__ = 'BaseTrainer'
1349+
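# _latest_training_job is never assigned explicitly. NOTE(review): an unspecced MagicMock auto-creates attributes on access, so hasattr(mock_trainer, '_latest_training_job') is still True here - confirm this test reaches the intended code path.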
1350+
1351+
# This should fail during model resolution, not in BaseTrainer handling
1352+
with pytest.raises(ValidationError, match="Failed to resolve model"):
1353+
BaseEvaluator(
1354+
model=mock_trainer,
1355+
s3_output_path=DEFAULT_S3_OUTPUT,
1356+
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1357+
sagemaker_session=mock_session,
1358+
)
1359+
1360+
def test_base_trainer_without_output_model_package_arn_attribute(self, mock_session):
1361+
"""Test BaseTrainer with training job but missing output_model_package_arn attribute."""
1362+
1363+
# Create a custom class that doesn't have output_model_package_arn
1364+
class MockTrainingJobWithoutArn:
1365+
pass
1366+
1367+
# Create mock BaseTrainer with _latest_training_job but no output_model_package_arn
1368+
mock_trainer = MagicMock()
1369+
mock_trainer.__class__.__name__ = 'BaseTrainer'
1370+
mock_trainer._latest_training_job = MockTrainingJobWithoutArn()
1371+
1372+
with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"):
1373+
BaseEvaluator(
1374+
model=mock_trainer,
1375+
s3_output_path=DEFAULT_S3_OUTPUT,
1376+
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1377+
sagemaker_session=mock_session,
1378+
)

0 commit comments

Comments
 (0)