|
20 | 20 | from sagemaker.core.shapes import VpcConfig |
21 | 21 | from sagemaker.core.resources import ModelPackageGroup, Artifact |
22 | 22 | from sagemaker.core.shapes import ArtifactSource, ArtifactSourceType |
| 23 | +from sagemaker.core.utils.utils import Unassigned |
| 24 | +from sagemaker.train.base_trainer import BaseTrainer |
23 | 25 |
|
24 | 26 | from sagemaker.train.evaluate.base_evaluator import BaseEvaluator |
25 | 27 |
|
@@ -1291,3 +1293,86 @@ def test_with_all_optional_params(self, mock_resolve, mock_session, mock_model_i |
1291 | 1293 | assert evaluator.networking == vpc_config |
1292 | 1294 | assert evaluator.kms_key_id == "arn:aws:kms:us-west-2:123456789012:key/12345" |
1293 | 1295 | assert evaluator.region == DEFAULT_REGION |
| 1296 | + |
| 1297 | + |
| 1298 | +class TestBaseTrainerHandling: |
| 1299 | + """Tests for BaseTrainer model handling.""" |
| 1300 | + |
| 1301 | + @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") |
| 1302 | + def test_base_trainer_with_valid_training_job(self, mock_resolve, mock_session, mock_model_info_with_package): |
| 1303 | + """Test BaseTrainer with valid completed training job.""" |
| 1304 | + mock_resolve.return_value = mock_model_info_with_package |
| 1305 | + |
| 1306 | + # Create mock BaseTrainer with completed training job |
| 1307 | + mock_trainer = MagicMock(spec=BaseTrainer) |
| 1308 | + mock_training_job = MagicMock() |
| 1309 | + mock_training_job.output_model_package_arn = DEFAULT_MODEL_PACKAGE_ARN |
| 1310 | + mock_trainer._latest_training_job = mock_training_job |
| 1311 | + |
| 1312 | + evaluator = BaseEvaluator( |
| 1313 | + model=mock_trainer, |
| 1314 | + s3_output_path=DEFAULT_S3_OUTPUT, |
| 1315 | + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, |
| 1316 | + sagemaker_session=mock_session, |
| 1317 | + ) |
| 1318 | + |
| 1319 | + # Verify model resolution was called with the training job's model package ARN |
| 1320 | + mock_resolve.assert_called_once_with( |
| 1321 | + base_model=DEFAULT_MODEL_PACKAGE_ARN, |
| 1322 | + sagemaker_session=mock_session |
| 1323 | + ) |
| 1324 | + assert evaluator.model == mock_trainer |
| 1325 | + |
| 1326 | + @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") |
| 1327 | + def test_base_trainer_with_unassigned_arn(self, mock_resolve, mock_session): |
| 1328 | + """Test BaseTrainer with Unassigned output_model_package_arn raises error.""" |
| 1329 | + # Create mock BaseTrainer with Unassigned ARN |
| 1330 | + mock_trainer = MagicMock(spec=BaseTrainer) |
| 1331 | + mock_training_job = MagicMock() |
| 1332 | + mock_training_job.output_model_package_arn = Unassigned() |
| 1333 | + mock_trainer._latest_training_job = mock_training_job |
| 1334 | + |
| 1335 | + with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"): |
| 1336 | + BaseEvaluator( |
| 1337 | + model=mock_trainer, |
| 1338 | + s3_output_path=DEFAULT_S3_OUTPUT, |
| 1339 | + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, |
| 1340 | + sagemaker_session=mock_session, |
| 1341 | + ) |
| 1342 | + |
| 1343 | + @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") |
| 1344 | + def test_base_trainer_without_training_job(self, mock_resolve, mock_session): |
| 1345 | + """Test BaseTrainer without _latest_training_job falls through to normal processing.""" |
| 1346 | + # Create mock BaseTrainer without _latest_training_job attribute |
| 1347 | + mock_trainer = MagicMock() |
| 1348 | + mock_trainer.__class__.__name__ = 'BaseTrainer' |
| 1349 | + # Don't set _latest_training_job attribute at all |
| 1350 | + |
| 1351 | + # This should fail during model resolution, not in BaseTrainer handling |
| 1352 | + with pytest.raises(ValidationError, match="Failed to resolve model"): |
| 1353 | + BaseEvaluator( |
| 1354 | + model=mock_trainer, |
| 1355 | + s3_output_path=DEFAULT_S3_OUTPUT, |
| 1356 | + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, |
| 1357 | + sagemaker_session=mock_session, |
| 1358 | + ) |
| 1359 | + |
| 1360 | + def test_base_trainer_without_output_model_package_arn_attribute(self, mock_session): |
| 1361 | + """Test BaseTrainer with training job but missing output_model_package_arn attribute.""" |
| 1362 | + |
| 1363 | + # Create a custom class that doesn't have output_model_package_arn |
| 1364 | + class MockTrainingJobWithoutArn: |
| 1365 | + pass |
| 1366 | + |
| 1367 | + # Create mock BaseTrainer with _latest_training_job but no output_model_package_arn |
| 1368 | + mock_trainer = MagicMock() |
| 1369 | + mock_trainer.__class__.__name__ = 'BaseTrainer' |
| 1370 | + mock_trainer._latest_training_job = MockTrainingJobWithoutArn() |
| 1371 | + |
| 1372 | + with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"): |
| 1373 | + BaseEvaluator( |
| 1374 | + model=mock_trainer, |
| 1375 | + s3_output_path=DEFAULT_S3_OUTPUT, |
| 1376 | + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, |
| 1377 | + sagemaker_session=mock_session, |
| 1378 | + ) |
0 commit comments