Skip to content

Commit 6dc76a3

Browse files
author
Roja Reddy Sareddy
committed
feat: Evaluator handshake with trainer
1 parent cb699a8 commit 6dc76a3

File tree

8 files changed

+99
-109
lines changed

8 files changed

+99
-109
lines changed

sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from dataclasses import dataclass
1414
from enum import Enum
1515
import re
16+
from sagemaker.train.base_trainer import BaseTrainer
17+
from sagemaker.core.utils.utils import Unassigned
1618

1719

1820
class _ModelType(Enum):
@@ -65,14 +67,14 @@ def __init__(self, sagemaker_session=None):
6567

6668
def resolve_model_info(
6769
self,
68-
base_model: Union[str, 'ModelPackage'],
70+
base_model: Union[str, BaseTrainer, 'ModelPackage'],
6971
hub_name: Optional[str] = None
7072
) -> _ModelInfo:
7173
"""
7274
Resolve model information from various input types.
7375
7476
Args:
75-
base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN
77+
base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN or BaseTrainer object with a completed job
7678
hub_name: Optional hub name for JumpStart models (defaults to SageMakerPublicHub)
7779
7880
Returns:
@@ -88,6 +90,17 @@ def resolve_model_info(
8890
return self._resolve_model_package_arn(base_model)
8991
else:
9092
return self._resolve_jumpstart_model(base_model, hub_name or self.DEFAULT_HUB_NAME)
93+
# Handle BaseTrainer type
94+
elif isinstance(base_model, BaseTrainer):
95+
if hasattr(base_model, '_latest_training_job') and hasattr(base_model._latest_training_job,
96+
'output_model_package_arn'):
97+
arn = base_model._latest_training_job.output_model_package_arn
98+
if not isinstance(arn, Unassigned):
99+
return self._resolve_model_package_arn(arn)
100+
else:
101+
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
102+
else:
103+
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
91104
else:
92105
# Not a string, so assume it's a ModelPackage object
93106
# Check if it has the expected attributes of a ModelPackage

sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,13 @@
1313

1414
from pydantic import BaseModel, validator
1515

16-
from sagemaker.core.resources import ModelPackageGroup
16+
from sagemaker.core.resources import ModelPackageGroup, ModelPackage
1717
from sagemaker.core.shapes import VpcConfig
18-
from sagemaker.core.utils.utils import Unassigned
1918

2019
if TYPE_CHECKING:
2120
from sagemaker.core.helper.session_helper import Session
22-
from sagemaker.train.base_trainer import BaseTrainer
2321

22+
from sagemaker.train.base_trainer import BaseTrainer
2423
# Module-level logger
2524
_logger = logging.getLogger(__name__)
2625

@@ -55,6 +54,7 @@ class BaseEvaluator(BaseModel):
5554
- JumpStart model ID (str): e.g., 'llama3-2-1b-instruct'
5655
- ModelPackage object: A fine-tuned model package
5756
- ModelPackage ARN (str): e.g., 'arn:aws:sagemaker:region:account:model-package/name/version'
57+
- BaseTrainer object: A completed training job (i.e., it must have _latest_training_job with output_model_package_arn populated)
5858
base_eval_name (Optional[str]): Optional base name for evaluation jobs. This name is used
5959
as the PipelineExecutionDisplayName when creating the SageMaker pipeline execution.
6060
The actual display name will be "{base_eval_name}-{timestamp}". This parameter can
@@ -88,7 +88,7 @@ class BaseEvaluator(BaseModel):
8888

8989
region: Optional[str] = None
9090
sagemaker_session: Optional[Any] = None
91-
model: Union[str, Any]
91+
model: Union[str, BaseTrainer, ModelPackage]
9292
base_eval_name: Optional[str] = None
9393
s3_output_path: str
9494
mlflow_resource_arn: Optional[str] = None
@@ -280,7 +280,7 @@ def _validate_mlflow_arn_format(cls, v: Optional[str]) -> Optional[str]:
280280
return v
281281

282282
@validator('model')
283-
def _resolve_model_info(cls, v: Union[str, "BaseTrainer", Any], values: dict) -> Union[str, Any]:
283+
def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: dict) -> Union[str, Any]:
284284
"""Resolve model information from various input types.
285285
286286
This validator uses the common model resolution utility to extract:
@@ -291,7 +291,7 @@ def _resolve_model_info(cls, v: Union[str, "BaseTrainer", Any], values: dict) ->
291291
The resolved information is stored in private attributes for use by subclasses.
292292
293293
Args:
294-
v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer).
294+
v (Union[str, BaseTrainer, ModelPackage]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer).
295295
values (dict): Dictionary of already-validated fields.
296296
297297
Returns:
@@ -304,25 +304,12 @@ def _resolve_model_info(cls, v: Union[str, "BaseTrainer", Any], values: dict) ->
304304
import os
305305

306306
try:
307-
# Handle BaseTrainer type
308-
if hasattr(v, '__class__') and v.__class__.__name__ == 'BaseTrainer' or hasattr(v, '_latest_training_job'):
309-
if hasattr(v._latest_training_job, 'output_model_package_arn'):
310-
arn = v._latest_training_job.output_model_package_arn
311-
if not isinstance(arn, Unassigned):
312-
model_to_resolve = arn
313-
else:
314-
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
315-
else:
316-
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
317-
else:
318-
model_to_resolve = v
319-
320307
# Get the session for resolution (may not be created yet due to validator order)
321308
session = values.get('sagemaker_session')
322309

323310
# Resolve model information
324311
model_info = _resolve_base_model(
325-
base_model=model_to_resolve,
312+
base_model=v,
326313
sagemaker_session=session
327314
)
328315

sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ class BenchMarkEvaluator(BaseEvaluator):
303303
subtasks: Optional[Union[str, List[str]]] = None
304304
evaluate_base_model: bool = True
305305
_hyperparameters: Optional[Any] = None
306-
306+
307307

308308
@validator('benchmark')
309309
def _validate_benchmark_model_compatibility(cls, v, values):

sagemaker-train/src/sagemaker/train/rlaif_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
286286
except TimeoutExceededError as e:
287287
logger.error("Error: %s", e)
288288

289-
self.latest_training_job = training_job
289+
self._latest_training_job = training_job
290290
return training_job
291291

292292
def _process_hyperparameters(self):

sagemaker-train/src/sagemaker/train/rlvr_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,5 +274,5 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
274274
except TimeoutExceededError as e:
275275
logger.error("Error: %s", e)
276276

277-
self.latest_training_job = training_job
277+
self._latest_training_job = training_job
278278
return training_job

sagemaker-train/src/sagemaker/train/sft_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
268268
except TimeoutExceededError as e:
269269
logger.error("Error: %s", e)
270270

271-
self.latest_training_job = training_job
271+
self._latest_training_job = training_job
272272
return training_job
273273

274274

sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
_ModelResolver,
2525
_resolve_base_model,
2626
)
27+
from sagemaker.train.base_trainer import BaseTrainer
28+
from sagemaker.core.utils.utils import Unassigned
2729

2830

2931
class TestModelType:
@@ -557,3 +559,74 @@ def test_resolve_base_model_with_hub_name(self, mock_resolver_class):
557559
_resolve_base_model("test-model", hub_name="CustomHub")
558560

559561
mock_resolver.resolve_model_info.assert_called_once_with("test-model", "CustomHub")
562+
563+
564+
class TestBaseTrainerHandling:
565+
"""Tests for BaseTrainer model handling in _resolve_base_model."""
566+
567+
def test_base_trainer_with_valid_training_job(self):
568+
"""Test BaseTrainer with valid completed training job."""
569+
# Create concrete BaseTrainer subclass for testing
570+
class TestTrainer(BaseTrainer):
571+
def train(self, input_data_config, wait=True, logs=True):
572+
pass
573+
574+
mock_trainer = TestTrainer()
575+
mock_training_job = MagicMock()
576+
mock_training_job.output_model_package_arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-package/1"
577+
mock_trainer._latest_training_job = mock_training_job
578+
579+
with patch('sagemaker.train.common_utils.model_resolution._ModelResolver._resolve_model_package_arn') as mock_resolve_arn:
580+
mock_resolve_arn.return_value = MagicMock()
581+
582+
result = _resolve_base_model(mock_trainer)
583+
584+
# Verify model package ARN resolution was called
585+
mock_resolve_arn.assert_called_once_with(
586+
"arn:aws:sagemaker:us-west-2:123456789012:model-package/my-package/1"
587+
)
588+
589+
def test_base_trainer_with_unassigned_arn(self):
590+
"""Test BaseTrainer with Unassigned output_model_package_arn raises error."""
591+
# Create concrete BaseTrainer subclass for testing
592+
class TestTrainer(BaseTrainer):
593+
def train(self, input_data_config, wait=True, logs=True):
594+
pass
595+
596+
mock_trainer = TestTrainer()
597+
mock_training_job = MagicMock()
598+
mock_training_job.output_model_package_arn = Unassigned()
599+
mock_trainer._latest_training_job = mock_training_job
600+
601+
with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
602+
_resolve_base_model(mock_trainer)
603+
604+
def test_base_trainer_without_training_job(self):
605+
"""Test BaseTrainer without _latest_training_job raises error."""
606+
# Create concrete BaseTrainer subclass for testing
607+
class TestTrainer(BaseTrainer):
608+
def train(self, input_data_config, wait=True, logs=True):
609+
pass
610+
611+
mock_trainer = TestTrainer()
612+
# Don't set _latest_training_job attribute at all
613+
614+
with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
615+
_resolve_base_model(mock_trainer)
616+
617+
def test_base_trainer_without_output_model_package_arn_attribute(self):
618+
"""Test BaseTrainer with training job but missing output_model_package_arn attribute."""
619+
# Create concrete BaseTrainer subclass for testing
620+
class TestTrainer(BaseTrainer):
621+
def train(self, input_data_config, wait=True, logs=True):
622+
pass
623+
624+
# Create a simple object without output_model_package_arn
625+
class TrainingJobWithoutArn:
626+
pass
627+
628+
mock_trainer = TestTrainer()
629+
mock_trainer._latest_training_job = TrainingJobWithoutArn()
630+
631+
with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
632+
_resolve_base_model(mock_trainer)

sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,86 +1293,3 @@ def test_with_all_optional_params(self, mock_resolve, mock_session, mock_model_i
12931293
assert evaluator.networking == vpc_config
12941294
assert evaluator.kms_key_id == "arn:aws:kms:us-west-2:123456789012:key/12345"
12951295
assert evaluator.region == DEFAULT_REGION
1296-
1297-
1298-
class TestBaseTrainerHandling:
1299-
"""Tests for BaseTrainer model handling."""
1300-
1301-
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
1302-
def test_base_trainer_with_valid_training_job(self, mock_resolve, mock_session, mock_model_info_with_package):
1303-
"""Test BaseTrainer with valid completed training job."""
1304-
mock_resolve.return_value = mock_model_info_with_package
1305-
1306-
# Create mock BaseTrainer with completed training job
1307-
mock_trainer = MagicMock(spec=BaseTrainer)
1308-
mock_training_job = MagicMock()
1309-
mock_training_job.output_model_package_arn = DEFAULT_MODEL_PACKAGE_ARN
1310-
mock_trainer._latest_training_job = mock_training_job
1311-
1312-
evaluator = BaseEvaluator(
1313-
model=mock_trainer,
1314-
s3_output_path=DEFAULT_S3_OUTPUT,
1315-
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1316-
sagemaker_session=mock_session,
1317-
)
1318-
1319-
# Verify model resolution was called with the training job's model package ARN
1320-
mock_resolve.assert_called_once_with(
1321-
base_model=DEFAULT_MODEL_PACKAGE_ARN,
1322-
sagemaker_session=mock_session
1323-
)
1324-
assert evaluator.model == mock_trainer
1325-
1326-
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
1327-
def test_base_trainer_with_unassigned_arn(self, mock_resolve, mock_session):
1328-
"""Test BaseTrainer with Unassigned output_model_package_arn raises error."""
1329-
# Create mock BaseTrainer with Unassigned ARN
1330-
mock_trainer = MagicMock(spec=BaseTrainer)
1331-
mock_training_job = MagicMock()
1332-
mock_training_job.output_model_package_arn = Unassigned()
1333-
mock_trainer._latest_training_job = mock_training_job
1334-
1335-
with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"):
1336-
BaseEvaluator(
1337-
model=mock_trainer,
1338-
s3_output_path=DEFAULT_S3_OUTPUT,
1339-
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1340-
sagemaker_session=mock_session,
1341-
)
1342-
1343-
@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
1344-
def test_base_trainer_without_training_job(self, mock_resolve, mock_session):
1345-
"""Test BaseTrainer without _latest_training_job falls through to normal processing."""
1346-
# Create mock BaseTrainer without _latest_training_job attribute
1347-
mock_trainer = MagicMock()
1348-
mock_trainer.__class__.__name__ = 'BaseTrainer'
1349-
# Don't set _latest_training_job attribute at all
1350-
1351-
# This should fail during model resolution, not in BaseTrainer handling
1352-
with pytest.raises(ValidationError, match="Failed to resolve model"):
1353-
BaseEvaluator(
1354-
model=mock_trainer,
1355-
s3_output_path=DEFAULT_S3_OUTPUT,
1356-
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1357-
sagemaker_session=mock_session,
1358-
)
1359-
1360-
def test_base_trainer_without_output_model_package_arn_attribute(self, mock_session):
1361-
"""Test BaseTrainer with training job but missing output_model_package_arn attribute."""
1362-
1363-
# Create a custom class that doesn't have output_model_package_arn
1364-
class MockTrainingJobWithoutArn:
1365-
pass
1366-
1367-
# Create mock BaseTrainer with _latest_training_job but no output_model_package_arn
1368-
mock_trainer = MagicMock()
1369-
mock_trainer.__class__.__name__ = 'BaseTrainer'
1370-
mock_trainer._latest_training_job = MockTrainingJobWithoutArn()
1371-
1372-
with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"):
1373-
BaseEvaluator(
1374-
model=mock_trainer,
1375-
s3_output_path=DEFAULT_S3_OUTPUT,
1376-
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
1377-
sagemaker_session=mock_session,
1378-
)

0 commit comments

Comments
 (0)