Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from dataclasses import dataclass
from enum import Enum
import re
from sagemaker.train.base_trainer import BaseTrainer
from sagemaker.core.utils.utils import Unassigned


class _ModelType(Enum):
Expand Down Expand Up @@ -65,14 +67,14 @@ def __init__(self, sagemaker_session=None):

def resolve_model_info(
self,
base_model: Union[str, 'ModelPackage'],
base_model: Union[str, BaseTrainer, 'ModelPackage'],
hub_name: Optional[str] = None
) -> _ModelInfo:
"""
Resolve model information from various input types.

Args:
base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN
base_model: Either a JumpStart model ID (str) or ModelPackage object/ARN or BaseTrainer object with a completed job
hub_name: Optional hub name for JumpStart models (defaults to SageMakerPublicHub)

Returns:
Expand All @@ -88,6 +90,17 @@ def resolve_model_info(
return self._resolve_model_package_arn(base_model)
else:
return self._resolve_jumpstart_model(base_model, hub_name or self.DEFAULT_HUB_NAME)
# Handle BaseTrainer type
elif isinstance(base_model, BaseTrainer):
if hasattr(base_model, '_latest_training_job') and hasattr(base_model._latest_training_job,
'output_model_package_arn'):
arn = base_model._latest_training_job.output_model_package_arn
if not isinstance(arn, Unassigned):
return self._resolve_model_package_arn(arn)
else:
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
else:
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
else:
# Not a string, so assume it's a ModelPackage object
# Check if it has the expected attributes of a ModelPackage
Expand Down
10 changes: 6 additions & 4 deletions sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@

from pydantic import BaseModel, validator

from sagemaker.core.resources import ModelPackageGroup
from sagemaker.core.resources import ModelPackageGroup, ModelPackage
from sagemaker.core.shapes import VpcConfig

if TYPE_CHECKING:
from sagemaker.core.helper.session_helper import Session

from sagemaker.train.base_trainer import BaseTrainer
# Module-level logger
_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,6 +54,7 @@ class BaseEvaluator(BaseModel):
- JumpStart model ID (str): e.g., 'llama3-2-1b-instruct'
- ModelPackage object: A fine-tuned model package
- ModelPackage ARN (str): e.g., 'arn:aws:sagemaker:region:account:model-package/name/version'
- BaseTrainer object: A completed training job (i.e., it must have _latest_training_job with output_model_package_arn populated)
base_eval_name (Optional[str]): Optional base name for evaluation jobs. This name is used
as the PipelineExecutionDisplayName when creating the SageMaker pipeline execution.
The actual display name will be "{base_eval_name}-{timestamp}". This parameter can
Expand Down Expand Up @@ -86,7 +88,7 @@ class BaseEvaluator(BaseModel):

region: Optional[str] = None
sagemaker_session: Optional[Any] = None
model: Union[str, Any]
model: Union[str, BaseTrainer, ModelPackage]
base_eval_name: Optional[str] = None
s3_output_path: str
mlflow_resource_arn: Optional[str] = None
Expand Down Expand Up @@ -278,7 +280,7 @@ def _validate_mlflow_arn_format(cls, v: Optional[str]) -> Optional[str]:
return v

@validator('model')
def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any]:
def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: dict) -> Union[str, Any]:
"""Resolve model information from various input types.

This validator uses the common model resolution utility to extract:
Expand All @@ -289,7 +291,7 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any
The resolved information is stored in private attributes for use by subclasses.

Args:
v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, or ARN).
v (Union[str, BaseTrainer, ModelPackage]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer).
values (dict): Dictionary of already-validated fields.

Returns:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,18 +300,10 @@ class BenchMarkEvaluator(BaseEvaluator):
"""

benchmark: _Benchmark
dataset: Union[str, Any] # Required field, must come before optional fields
subtasks: Optional[Union[str, List[str]]] = None
evaluate_base_model: bool = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make evaluate_base_model false for all evaluations.

_hyperparameters: Optional[Any] = None

@validator('dataset', pre=True)
def _resolve_dataset(cls, v):
    """Normalize the ``dataset`` field to a string (S3 URI or ARN).

    Delegates to ``BaseEvaluator._validate_and_resolve_dataset`` so the
    shared resolution and format-validation logic lives in one place
    instead of being duplicated per evaluator subclass.
    """
    resolved = BaseEvaluator._validate_and_resolve_dataset(v)
    return resolved


@validator('benchmark')
def _validate_benchmark_model_compatibility(cls, v, values):
Expand Down
2 changes: 1 addition & 1 deletion sagemaker-train/src/sagemaker/train/rlaif_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
except TimeoutExceededError as e:
logger.error("Error: %s", e)

self.latest_training_job = training_job
self._latest_training_job = training_job
return training_job

def _process_hyperparameters(self):
Expand Down
2 changes: 1 addition & 1 deletion sagemaker-train/src/sagemaker/train/rlvr_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,5 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
except TimeoutExceededError as e:
logger.error("Error: %s", e)

self.latest_training_job = training_job
self._latest_training_job = training_job
return training_job
2 changes: 1 addition & 1 deletion sagemaker-train/src/sagemaker/train/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
except TimeoutExceededError as e:
logger.error("Error: %s", e)

self.latest_training_job = training_job
self._latest_training_job = training_job
return training_job


Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
_ModelResolver,
_resolve_base_model,
)
from sagemaker.train.base_trainer import BaseTrainer
from sagemaker.core.utils.utils import Unassigned


class TestModelType:
Expand Down Expand Up @@ -557,3 +559,74 @@ def test_resolve_base_model_with_hub_name(self, mock_resolver_class):
_resolve_base_model("test-model", hub_name="CustomHub")

mock_resolver.resolve_model_info.assert_called_once_with("test-model", "CustomHub")


class TestBaseTrainerHandling:
    """Tests for how _resolve_base_model handles BaseTrainer inputs."""

    @staticmethod
    def _make_trainer():
        """Return a minimal concrete BaseTrainer (train() must be overridden)."""

        class _Trainer(BaseTrainer):
            def train(self, input_data_config, wait=True, logs=True):
                pass

        return _Trainer()

    def test_base_trainer_with_valid_training_job(self):
        """A trainer whose latest job carries a real ARN resolves through the model-package path."""
        arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-package/1"
        trainer = self._make_trainer()
        job = MagicMock()
        job.output_model_package_arn = arn
        trainer._latest_training_job = job

        target = (
            'sagemaker.train.common_utils.model_resolution.'
            '_ModelResolver._resolve_model_package_arn'
        )
        with patch(target) as resolve_arn:
            resolve_arn.return_value = MagicMock()

            _resolve_base_model(trainer)

            # The trainer's output ARN must be forwarded verbatim.
            resolve_arn.assert_called_once_with(arn)

    def test_base_trainer_with_unassigned_arn(self):
        """An Unassigned output_model_package_arn means the job never completed -> ValueError."""
        trainer = self._make_trainer()
        job = MagicMock()
        job.output_model_package_arn = Unassigned()
        trainer._latest_training_job = job

        with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
            _resolve_base_model(trainer)

    def test_base_trainer_without_training_job(self):
        """A trainer that never ran (no _latest_training_job attribute) -> ValueError."""
        trainer = self._make_trainer()

        with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
            _resolve_base_model(trainer)

    def test_base_trainer_without_output_model_package_arn_attribute(self):
        """A training job lacking the output_model_package_arn attribute -> ValueError."""

        class _JobWithoutArn:
            pass

        trainer = self._make_trainer()
        trainer._latest_training_job = _JobWithoutArn()

        with pytest.raises(ValueError, match="BaseTrainer must have completed training job"):
            _resolve_base_model(trainer)
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from sagemaker.core.shapes import VpcConfig
from sagemaker.core.resources import ModelPackageGroup, Artifact
from sagemaker.core.shapes import ArtifactSource, ArtifactSourceType
from sagemaker.core.utils.utils import Unassigned
from sagemaker.train.base_trainer import BaseTrainer

from sagemaker.train.evaluate.base_evaluator import BaseEvaluator

Expand Down
Loading
Loading