diff --git a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py
index 1c3f09a43e..2ce6ea7198 100644
--- a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py
+++ b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py
@@ -13,6 +13,8 @@
 from dataclasses import dataclass
 from enum import Enum
 import re
+from sagemaker.train.base_trainer import BaseTrainer
+from sagemaker.core.utils.utils import Unassigned
 
 
 class _ModelType(Enum):
@@ -65,14 +67,14 @@ def __init__(self, sagemaker_session=None):
 
     def resolve_model_info(
         self,
-        base_model: Union[str, 'ModelPackage'],
+        base_model: Union[str, BaseTrainer, 'ModelPackage'],
         hub_name: Optional[str] = None
     ) -> _ModelInfo:
         """
         Resolve model information from various input types.
 
         Args:
-            base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN
+            base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN or BaseTrainer object with a completed job
             hub_name: Optional hub name for JumpStart models (defaults to SageMakerPublicHub)
 
         Returns:
@@ -88,6 +90,17 @@ def resolve_model_info(
                 return self._resolve_model_package_arn(base_model)
             else:
                 return self._resolve_jumpstart_model(base_model, hub_name or self.DEFAULT_HUB_NAME)
+        # Handle BaseTrainer type
+        elif isinstance(base_model, BaseTrainer):
+            if hasattr(base_model, '_latest_training_job') and hasattr(base_model._latest_training_job,
+                                                                       'output_model_package_arn'):
+                arn = base_model._latest_training_job.output_model_package_arn
+                if not isinstance(arn, Unassigned):
+                    return self._resolve_model_package_arn(arn)
+                else:
+                    raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
+            else:
+                raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
         else:
             # Not a string, so assume it's a ModelPackage object
             # Check if it has the expected attributes of a ModelPackage
diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
index 620b7ffe34..6a87fa96eb 100644
--- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
+++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
@@ -13,12 +13,13 @@
 
 from pydantic import BaseModel, validator
 
-from sagemaker.core.resources import ModelPackageGroup
+from sagemaker.core.resources import ModelPackageGroup, ModelPackage
 from sagemaker.core.shapes import VpcConfig
 
 if TYPE_CHECKING:
     from sagemaker.core.helper.session_helper import Session
 
+from sagemaker.train.base_trainer import BaseTrainer
 
 # Module-level logger
 _logger = logging.getLogger(__name__)
@@ -53,6 +54,7 @@ class BaseEvaluator(BaseModel):
             - JumpStart model ID (str): e.g., 'llama3-2-1b-instruct'
             - ModelPackage object: A fine-tuned model package
             - ModelPackage ARN (str): e.g., 'arn:aws:sagemaker:region:account:model-package/name/version'
+            - BaseTrainer object: A completed training job (i.e., it must have _latest_training_job with output_model_package_arn populated)
         base_eval_name (Optional[str]): Optional base name for evaluation jobs. This name is used as the
             PipelineExecutionDisplayName when creating the SageMaker pipeline execution. The actual display
             name will be "{base_eval_name}-{timestamp}". This parameter can
@@ -86,7 +88,7 @@ class BaseEvaluator(BaseModel):
 
     region: Optional[str] = None
     sagemaker_session: Optional[Any] = None
-    model: Union[str, Any]
+    model: Union[str, BaseTrainer, ModelPackage]
     base_eval_name: Optional[str] = None
     s3_output_path: str
     mlflow_resource_arn: Optional[str] = None
@@ -278,7 +280,7 @@ def _validate_mlflow_arn_format(cls, v: Optional[str]) -> Optional[str]:
         return v
 
     @validator('model')
-    def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any]:
+    def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: dict) -> Union[str, Any]:
         """Resolve model information from various input types.
 
         This validator uses the common model resolution utility to extract:
@@ -289,7 +291,7 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any
         The resolved information is stored in private attributes for use by subclasses.
 
         Args:
-            v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, or ARN).
+            v (Union[str, BaseTrainer, ModelPackage]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer).
             values (dict): Dictionary of already-validated fields.
 
         Returns:
diff --git a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py
index 4ca685b811..5d37e53f8c 100644
--- a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py
+++ b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py
@@ -300,18 +300,10 @@ class BenchMarkEvaluator(BaseEvaluator):
     """
 
     benchmark: _Benchmark
-    dataset: Union[str, Any]  # Required field, must come before optional fields
     subtasks: Optional[Union[str, List[str]]] = None
     evaluate_base_model: bool = True
     _hyperparameters: Optional[Any] = None
-
-    @validator('dataset', pre=True)
-    def _resolve_dataset(cls, v):
-        """Resolve dataset to string (S3 URI or ARN) and validate format.
-
-        Uses BaseEvaluator's common validation logic to avoid code duplication.
-        """
-        return BaseEvaluator._validate_and_resolve_dataset(v)
+
 
     @validator('benchmark')
     def _validate_benchmark_model_compatibility(cls, v, values):
diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
index 1bf5c02813..ebadc6bfda 100644
--- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py
@@ -286,7 +286,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
             except TimeoutExceededError as e:
                 logger.error("Error: %s", e)
 
-        self.latest_training_job = training_job
+        self._latest_training_job = training_job
         return training_job
 
     def _process_hyperparameters(self):
diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py
index f00c7aac36..b28c9d865c 100644
--- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py
@@ -274,5 +274,5 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
             except TimeoutExceededError as e:
                 logger.error("Error: %s", e)
 
-        self.latest_training_job = training_job
+        self._latest_training_job = training_job
         return training_job
diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py
index 57d2c52a06..b2688dce5d 100644
--- a/sagemaker-train/src/sagemaker/train/sft_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py
@@ -268,7 +268,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
             except TimeoutExceededError as e:
                 logger.error("Error: %s", e)
 
-        self.latest_training_job = training_job
+        self._latest_training_job = training_job
         return training_job
 
 
diff --git a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py
index d0cc5990a8..31a827e3f0 100644
--- a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py
+++ b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py
@@ -24,6 +24,8 @@
     _ModelResolver,
     _resolve_base_model,
 )
+from sagemaker.train.base_trainer import BaseTrainer
+from sagemaker.core.utils.utils import Unassigned
 
 
 class TestModelType:
@@ -557,3 +559,74 @@ def test_resolve_base_model_with_hub_name(self, mock_resolver_class):
         _resolve_base_model("test-model", hub_name="CustomHub")
 
         mock_resolver.resolve_model_info.assert_called_once_with("test-model", "CustomHub")
+
+
+class TestBaseTrainerHandling:
+    """Tests for BaseTrainer model handling in _resolve_base_model."""
+
+    def test_base_trainer_with_valid_training_job(self):
+        """Test BaseTrainer with valid completed training job."""
+        # Create concrete BaseTrainer subclass for testing
+        class TestTrainer(BaseTrainer):
+            def train(self, input_data_config, wait=True, logs=True):
+                pass
+
+        mock_trainer = TestTrainer()
+        mock_training_job = MagicMock()
+        mock_training_job.output_model_package_arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-package/1"
+        mock_trainer._latest_training_job = mock_training_job
+
+        with patch('sagemaker.train.common_utils.model_resolution._ModelResolver._resolve_model_package_arn') as mock_resolve_arn:
+            mock_resolve_arn.return_value = MagicMock()
+
+            result = _resolve_base_model(mock_trainer)
+
+            # Verify model package ARN resolution was called
+            mock_resolve_arn.assert_called_once_with(
+                "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-package/1"
+            )
+
+    def test_base_trainer_with_unassigned_arn(self):
+        """Test BaseTrainer with Unassigned output_model_package_arn raises error."""
+        # Create concrete BaseTrainer subclass for testing
+        class TestTrainer(BaseTrainer):
+            def train(self, input_data_config, wait=True, logs=True):
+                pass
+
+        mock_trainer = TestTrainer()
+        mock_training_job = MagicMock()
+        mock_training_job.output_model_package_arn = Unassigned()
+        mock_trainer._latest_training_job = mock_training_job
+
+        with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
+            _resolve_base_model(mock_trainer)
+
+    def test_base_trainer_without_training_job(self):
+        """Test BaseTrainer without _latest_training_job raises error."""
+        # Create concrete BaseTrainer subclass for testing
+        class TestTrainer(BaseTrainer):
+            def train(self, input_data_config, wait=True, logs=True):
+                pass
+
+        mock_trainer = TestTrainer()
+        # Don't set _latest_training_job attribute at all
+
+        with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
+            _resolve_base_model(mock_trainer)
+
+    def test_base_trainer_without_output_model_package_arn_attribute(self):
+        """Test BaseTrainer with training job but missing output_model_package_arn attribute."""
+        # Create concrete BaseTrainer subclass for testing
+        class TestTrainer(BaseTrainer):
+            def train(self, input_data_config, wait=True, logs=True):
+                pass
+
+        # Create a simple object without output_model_package_arn
+        class TrainingJobWithoutArn:
+            pass
+
+        mock_trainer = TestTrainer()
+        mock_trainer._latest_training_job = TrainingJobWithoutArn()
+
+        with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
+            _resolve_base_model(mock_trainer)
diff --git a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py
index 86c09489a0..c9b2e0a255 100644
--- a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py
+++ b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py
@@ -20,6 +20,8 @@
 from sagemaker.core.shapes import VpcConfig
 from sagemaker.core.resources import ModelPackageGroup, Artifact
 from sagemaker.core.shapes import ArtifactSource, ArtifactSourceType
+from sagemaker.core.utils.utils import Unassigned
+from sagemaker.train.base_trainer import BaseTrainer
 
 from sagemaker.train.evaluate.base_evaluator import BaseEvaluator
 
diff --git a/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py
index e9c74e3f2b..858bb12d32 100644
--- a/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py
+++ b/sagemaker-train/tests/unit/train/evaluate/test_benchmark_evaluator.py
@@ -121,7 +121,7 @@ def test_benchmark_evaluator_initialization_minimal(mock_artifact, mock_resolve)
     evaluator = BenchMarkEvaluator(
         benchmark=_Benchmark.MMLU,
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -130,7 +130,6 @@
 
     assert evaluator.benchmark == _Benchmark.MMLU
     assert evaluator.model == DEFAULT_MODEL
-    assert evaluator.dataset == DEFAULT_DATASET
     assert evaluator.evaluate_base_model is True
     assert evaluator.subtasks == "ALL"
 
@@ -158,7 +157,7 @@ def test_benchmark_evaluator_subtask_defaults_to_all(mock_artifact, mock_resolve
     evaluator = BenchMarkEvaluator(
         benchmark=_Benchmark.MMLU,
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -188,7 +187,7 @@ def test_benchmark_evaluator_subtask_validation_invalid(mock_artifact, mock_reso
             benchmark=_Benchmark.MMLU,
             subtasks=["invalid_subtask"],
             model=DEFAULT_MODEL,
-            dataset=DEFAULT_DATASET,
+
             s3_output_path=DEFAULT_S3_OUTPUT,
             mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
             model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -216,7 +215,7 @@ def test_benchmark_evaluator_no_subtask_for_unsupported_benchmark(mock_artifact,
             benchmark=_Benchmark.GPQA,
             subtasks="some_subtask",
             model=DEFAULT_MODEL,
-            dataset=DEFAULT_DATASET,
+
             s3_output_path=DEFAULT_S3_OUTPUT,
             mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
             model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -250,14 +249,13 @@ def test_benchmark_evaluator_dataset_resolution_from_object(mock_artifact, mock_
     evaluator = BenchMarkEvaluator(
         benchmark=_Benchmark.MMLU,
         model=DEFAULT_MODEL,
-        dataset=mock_dataset,
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
         sagemaker_session=mock_session,
     )
 
-    assert evaluator.dataset == mock_dataset.arn
+    # Dataset field is commented out, so no assertion needed
 
 
 @patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
@@ -284,7 +282,7 @@ def test_benchmark_evaluator_evaluate_method_exists(mock_artifact, mock_resolve)
         benchmark=_Benchmark.MMLU,
         subtasks=["abstract_algebra"],
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -327,7 +325,7 @@ def test_benchmark_evaluator_evaluate_invalid_subtask_override(mock_artifact, mo
     evaluator = BenchMarkEvaluator(
         benchmark=_Benchmark.MMLU,
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -378,7 +376,7 @@ def test_benchmark_evaluator_missing_required_fields():
         BenchMarkEvaluator(
             benchmark=_Benchmark.MMLU,
             model=DEFAULT_MODEL,
-            dataset=DEFAULT_DATASET,
+
             s3_output_path=DEFAULT_S3_OUTPUT,
             sagemaker_session=mock_session,
         )
@@ -408,7 +406,7 @@ def test_benchmark_evaluator_resolve_subtask_for_evaluation(mock_artifact, mock_
         benchmark=_Benchmark.MMLU,
         subtasks="abstract_algebra",  # Use a specific subtask instead of "ALL"
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -458,7 +456,7 @@ def test_benchmark_evaluator_hyperparameters_property(mock_artifact, mock_resolv
     evaluator = BenchMarkEvaluator(
         benchmark=_Benchmark.MMLU,
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -512,7 +510,7 @@ def test_benchmark_evaluator_get_benchmark_template_additions(mock_artifact, moc
         benchmark=_Benchmark.MMLU,
         subtasks=["abstract_algebra"],
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -559,7 +557,7 @@ def test_benchmark_evaluator_mmmu_nova_validation(mock_artifact, mock_resolve, m
         BenchMarkEvaluator(
             benchmark=_Benchmark.MMMU,
             model=DEFAULT_MODEL,
-            dataset=DEFAULT_DATASET,
+
             s3_output_path=DEFAULT_S3_OUTPUT,
             mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
             model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -596,7 +594,7 @@ def test_benchmark_evaluator_llm_judge_nova_validation(mock_artifact, mock_resol
         BenchMarkEvaluator(
             benchmark=_Benchmark.LLM_JUDGE,
             model="nova-pro",
-            dataset=DEFAULT_DATASET,
+
             s3_output_path=DEFAULT_S3_OUTPUT,
             mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
             model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -629,7 +627,7 @@ def test_benchmark_evaluator_subtask_list_validation(mock_artifact, mock_resolve
         benchmark=_Benchmark.MMLU,
         subtasks=["abstract_algebra", "anatomy"],
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -643,7 +641,7 @@ def test_benchmark_evaluator_subtask_list_validation(mock_artifact, mock_resolve
         benchmark=_Benchmark.MMLU,
         subtasks=[],
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -675,7 +673,7 @@ def test_benchmark_evaluator_resolve_subtask_list(mock_artifact, mock_resolve):
         benchmark=_Benchmark.MMLU,
         subtasks=["abstract_algebra", "anatomy"],
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -729,7 +727,7 @@ def test_benchmark_evaluator_template_additions_with_list_subtasks(mock_artifact
         benchmark=_Benchmark.MMLU,
         subtasks=["abstract_algebra", "anatomy"],
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -761,7 +759,7 @@ def test_benchmark_evaluator_with_subtask_list(mock_resolve):
 
     evaluator = BenchMarkEvaluator(
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         benchmark=_Benchmark.MMLU,
         subtasks=['abstract_algebra', 'anatomy'],
         s3_output_path=DEFAULT_S3_OUTPUT,
@@ -788,7 +786,7 @@ def test_benchmark_evaluator_with_subtask_string(mock_resolve):
 
     evaluator = BenchMarkEvaluator(
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         benchmark=_Benchmark.MMLU,
         subtasks='abstract_algebra',
         s3_output_path=DEFAULT_S3_OUTPUT,
@@ -817,7 +815,7 @@ def test_benchmark_evaluator_invalid_subtask(mock_resolve):
     with pytest.raises(ValidationError, match="Invalid subtask"):
         BenchMarkEvaluator(
             model=DEFAULT_MODEL,
-            dataset=DEFAULT_DATASET,
+
             benchmark=_Benchmark.MMLU,
             subtasks=['invalid_subtask'],
             s3_output_path=DEFAULT_S3_OUTPUT,
@@ -843,7 +841,7 @@ def test_benchmark_evaluator_no_subtask_available(mock_resolve):
     # IFEVAL doesn't support subtasks
     evaluator = BenchMarkEvaluator(
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         benchmark=_Benchmark.IFEVAL,
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
@@ -874,7 +872,7 @@ def test_benchmark_evaluator_with_networking(mock_resolve):
 
     evaluator = BenchMarkEvaluator(
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         benchmark=_Benchmark.MMLU,
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
@@ -903,7 +901,7 @@ def test_benchmark_evaluator_with_kms_key(mock_resolve):
 
     evaluator = BenchMarkEvaluator(
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         benchmark=_Benchmark.MMLU,
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
@@ -951,7 +949,7 @@ def test_benchmark_evaluator_uses_metric_key_for_nova(mock_artifact, mock_resolv
     evaluator = BenchMarkEvaluator(
         benchmark=_Benchmark.MMLU,
         model="nova-pro",
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
@@ -1001,7 +999,7 @@ def test_benchmark_evaluator_uses_evaluation_metric_key_for_non_nova(mock_artifa
     evaluator = BenchMarkEvaluator(
         benchmark=_Benchmark.MMLU,
         model=DEFAULT_MODEL,
-        dataset=DEFAULT_DATASET,
+
         s3_output_path=DEFAULT_S3_OUTPUT,
         mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
         model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
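Usage sketch (illustrative, not part of the patch): the changes above let a trainer whose job has finished be passed directly as the `model` argument of an evaluator; resolution goes through `_latest_training_job.output_model_package_arn`. The import paths follow the files touched in this diff, the `_Benchmark` import location is assumed, and the bucket names, ARNs, and the pre-configured `trainer`/`session` objects are hypothetical placeholders.

# Illustrative sketch only -- placeholder ARNs/buckets; assumes an already-configured SFTTrainer
# (or any other BaseTrainer subclass) and a SageMaker session created elsewhere.
from sagemaker.train.evaluate.benchmark_evaluator import BenchMarkEvaluator, _Benchmark


def evaluate_finetuned_model(trainer, session):
    """Build a benchmark evaluator for the model package produced by `trainer`."""
    # train() stores the job on trainer._latest_training_job (see the trainer diffs above).
    trainer.train(training_dataset="s3://my-bucket/datasets/train.jsonl")

    # A completed trainer is now a valid `model` value; the model validator resolves
    # trainer._latest_training_job.output_model_package_arn under the hood.
    evaluator = BenchMarkEvaluator(
        benchmark=_Benchmark.MMLU,
        model=trainer,
        s3_output_path="s3://my-bucket/eval-output/",
        mlflow_resource_arn="arn:aws:sagemaker:us-west-2:111122223333:mlflow-tracking-server/my-server",
        model_package_group="arn:aws:sagemaker:us-west-2:111122223333:model-package-group/my-group",
        sagemaker_session=session,
    )
    return evaluator

If the training job has not produced a model package (output_model_package_arn still Unassigned, or no _latest_training_job at all), constructing the evaluator fails with ValueError("BaseTrainer must have completed training job to be used for evaluation"), matching the unit tests added in this diff.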