21 changes: 18 additions & 3 deletions sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
@@ -15,9 +15,11 @@

from sagemaker.core.resources import ModelPackageGroup
from sagemaker.core.shapes import VpcConfig
from sagemaker.core.utils.utils import Unassigned

if TYPE_CHECKING:
from sagemaker.core.helper.session_helper import Session
from sagemaker.train.base_trainer import BaseTrainer

# Module-level logger
_logger = logging.getLogger(__name__)
@@ -278,7 +280,7 @@ def _validate_mlflow_arn_format(cls, v: Optional[str]) -> Optional[str]:
return v

@validator('model')
def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any]:
def _resolve_model_info(cls, v: Union[str, "BaseTrainer", Any], values: dict) -> Union[str, Any]:
"""Resolve model information from various input types.

This validator uses the common model resolution utility to extract:
@@ -289,7 +291,7 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any
The resolved information is stored in private attributes for use by subclasses.

Args:
v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, or ARN).
v (Union[str, Any]): Model identifier (JumpStart ID, ModelPackage, ARN, or BaseTrainer).
values (dict): Dictionary of already-validated fields.

Returns:
@@ -302,12 +304,25 @@ def _resolve_model_info(cls, v: Union[str, Any], values: dict) -> Union[str, Any
import os

try:
# Handle BaseTrainer type
if hasattr(v, '__class__') and v.__class__.__name__ == 'BaseTrainer' or hasattr(v, '_latest_training_job'):
Member:

With the condition v.__class__.__name__ == 'BaseTrainer', will I be able to supply SFTTrainer?

Contributor Author (@rsareddy0329, Dec 16, 2025):

Previous jobs were submitted fine, so it might have worked. I moved the code to model_resolution.py to keep it centralized in _resolve_base_model and updated it to use isinstance to be sure.
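
A minimal sketch of the isinstance-based handling that reply describes, assuming it lives in model_resolution.py as part of _resolve_base_model; only BaseTrainer, Unassigned, the _latest_training_job / output_model_package_arn attributes, and the error message come from this diff, and the rest is illustrative:

# Illustrative only: the real _resolve_base_model body is not shown in this diff.
from sagemaker.core.utils.utils import Unassigned
from sagemaker.train.base_trainer import BaseTrainer

def _resolve_base_model(base_model, sagemaker_session=None):
    # isinstance also matches subclasses such as SFTTrainer, unlike a
    # __class__.__name__ == 'BaseTrainer' comparison.
    if isinstance(base_model, BaseTrainer):
        training_job = getattr(base_model, "_latest_training_job", None)
        arn = getattr(training_job, "output_model_package_arn", None)
        if arn is None or isinstance(arn, Unassigned):
            raise ValueError(
                "BaseTrainer must have completed training job to be used for evaluation"
            )
        base_model = arn
    # ... fall through to the existing JumpStart ID / ModelPackage / ARN resolution ...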

if hasattr(v._latest_training_job, 'output_model_package_arn'):
arn = v._latest_training_job.output_model_package_arn
if not isinstance(arn, Unassigned):
model_to_resolve = arn
else:
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
else:
raise ValueError("BaseTrainer must have completed training job to be used for evaluation")
else:
model_to_resolve = v

# Get the session for resolution (may not be created yet due to validator order)
session = values.get('sagemaker_session')

# Resolve model information
model_info = _resolve_base_model(
base_model=v,
base_model=model_to_resolve,
sagemaker_session=session
)

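For context, a hedged usage sketch of what this change enables; the keyword arguments mirror the unit tests below, while the trainer object and the ARN and S3 values are placeholders:

from sagemaker.train.evaluate.base_evaluator import BaseEvaluator

# trainer is assumed to be a BaseTrainer subclass whose latest training job has
# completed and produced an output_model_package_arn.
evaluator = BaseEvaluator(
    model=trainer,  # now accepted alongside JumpStart IDs, ModelPackages, and ARNs
    s3_output_path="s3://example-bucket/eval-output/",
    mlflow_resource_arn="arn:aws:sagemaker:us-west-2:111122223333:mlflow-tracking-server/example",
)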
@@ -300,18 +300,10 @@ class BenchMarkEvaluator(BaseEvaluator):
"""

benchmark: _Benchmark
dataset: Union[str, Any] # Required field, must come before optional fields
subtasks: Optional[Union[str, List[str]]] = None
evaluate_base_model: bool = True
Collaborator:

Let's make evaluate_base_model false for all evaluations.
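
If that suggestion is adopted, the default would simply flip; a sketch of the change, not applied in this revision:

evaluate_base_model: bool = False  # reviewer-suggested default; this diff still sets True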

_hyperparameters: Optional[Any] = None

@validator('dataset', pre=True)
def _resolve_dataset(cls, v):
"""Resolve dataset to string (S3 URI or ARN) and validate format.

Uses BaseEvaluator's common validation logic to avoid code duplication.
"""
return BaseEvaluator._validate_and_resolve_dataset(v)

@validator('benchmark')
def _validate_benchmark_model_compatibility(cls, v, values):
85 changes: 85 additions & 0 deletions sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py
@@ -20,6 +20,8 @@
from sagemaker.core.shapes import VpcConfig
from sagemaker.core.resources import ModelPackageGroup, Artifact
from sagemaker.core.shapes import ArtifactSource, ArtifactSourceType
from sagemaker.core.utils.utils import Unassigned
from sagemaker.train.base_trainer import BaseTrainer

from sagemaker.train.evaluate.base_evaluator import BaseEvaluator

@@ -1291,3 +1293,86 @@ def test_with_all_optional_params(self, mock_resolve, mock_session, mock_model_i
assert evaluator.networking == vpc_config
assert evaluator.kms_key_id == "arn:aws:kms:us-west-2:123456789012:key/12345"
assert evaluator.region == DEFAULT_REGION


class TestBaseTrainerHandling:
"""Tests for BaseTrainer model handling."""

@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
def test_base_trainer_with_valid_training_job(self, mock_resolve, mock_session, mock_model_info_with_package):
"""Test BaseTrainer with valid completed training job."""
mock_resolve.return_value = mock_model_info_with_package

# Create mock BaseTrainer with completed training job
mock_trainer = MagicMock(spec=BaseTrainer)
mock_training_job = MagicMock()
mock_training_job.output_model_package_arn = DEFAULT_MODEL_PACKAGE_ARN
mock_trainer._latest_training_job = mock_training_job

evaluator = BaseEvaluator(
model=mock_trainer,
s3_output_path=DEFAULT_S3_OUTPUT,
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
sagemaker_session=mock_session,
)

# Verify model resolution was called with the training job's model package ARN
mock_resolve.assert_called_once_with(
base_model=DEFAULT_MODEL_PACKAGE_ARN,
sagemaker_session=mock_session
)
assert evaluator.model == mock_trainer

@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
def test_base_trainer_with_unassigned_arn(self, mock_resolve, mock_session):
"""Test BaseTrainer with Unassigned output_model_package_arn raises error."""
# Create mock BaseTrainer with Unassigned ARN
mock_trainer = MagicMock(spec=BaseTrainer)
mock_training_job = MagicMock()
mock_training_job.output_model_package_arn = Unassigned()
mock_trainer._latest_training_job = mock_training_job

with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"):
BaseEvaluator(
model=mock_trainer,
s3_output_path=DEFAULT_S3_OUTPUT,
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
sagemaker_session=mock_session,
)

@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
def test_base_trainer_without_training_job(self, mock_resolve, mock_session):
"""Test BaseTrainer without _latest_training_job falls through to normal processing."""
# Create mock BaseTrainer without _latest_training_job attribute
mock_trainer = MagicMock()
mock_trainer.__class__.__name__ = 'BaseTrainer'
# Don't set _latest_training_job attribute at all

# This should fail during model resolution, not in BaseTrainer handling
with pytest.raises(ValidationError, match="Failed to resolve model"):
BaseEvaluator(
model=mock_trainer,
s3_output_path=DEFAULT_S3_OUTPUT,
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
sagemaker_session=mock_session,
)

def test_base_trainer_without_output_model_package_arn_attribute(self, mock_session):
"""Test BaseTrainer with training job but missing output_model_package_arn attribute."""

# Create a custom class that doesn't have output_model_package_arn
class MockTrainingJobWithoutArn:
pass

# Create mock BaseTrainer with _latest_training_job but no output_model_package_arn
mock_trainer = MagicMock()
mock_trainer.__class__.__name__ = 'BaseTrainer'
mock_trainer._latest_training_job = MockTrainingJobWithoutArn()

with pytest.raises(ValidationError, match="BaseTrainer must have completed training job"):
BaseEvaluator(
model=mock_trainer,
s3_output_path=DEFAULT_S3_OUTPUT,
mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
sagemaker_session=mock_session,
)