Skip to content

Commit c4ca393

Browse files
committed
Add support for MetricDefinitions in ModelTrainer
For commit: 0215512
1 parent 5878547 commit c4ca393

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

sagemaker-core/src/sagemaker/core/modules/configs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
InstanceGroup,
4646
TensorBoardOutputConfig,
4747
CheckpointConfig,
48+
MetricDefinition,
4849
)
4950

5051

@@ -70,6 +71,7 @@
7071
"Compute",
7172
"Networking",
7273
"InputData",
74+
"MetricDefinition",
7375
]
7476

7577
from sagemaker.core.modules.utils import convert_unassigned_to_none

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
RemoteDebugConfig,
6565
SessionChainingConfig,
6666
InputData,
67+
MetricDefinition,
6768
)
6869

6970
from sagemaker.train.distributed import Torchrun, DistributedConfig
@@ -244,6 +245,7 @@ class ModelTrainer(BaseModel):
244245
_infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None)
245246
_session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None)
246247
_remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
248+
_metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
247249

248250
# Private Attributes for Recipes
249251
_temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
@@ -654,6 +656,7 @@ def train(
654656
training_image_config=self.training_image_config,
655657
container_entrypoint=container_entrypoint,
656658
container_arguments=container_arguments,
659+
metric_definitions=self._metric_definitions,
657660
)
658661

659662
resource_config = self.compute._to_resource_config()
@@ -1496,3 +1499,29 @@ def with_checkpoint_config(
14961499
"""
14971500
self.checkpoint_config = checkpoint_config or configs.CheckpointConfig()
14981501
return self
1502+
1503+
def with_metric_definitions(
1504+
self,
1505+
metric_definitions: List[MetricDefinition]
1506+
) -> "ModelTrainer": # noqa: D412
1507+
"""Set the metric definitions for the training job.
1508+
Example:
1509+
.. code:: python
1510+
from sagemaker.modules.train import ModelTrainer
1511+
from sagemaker.modules.configs import MetricDefinition
1512+
metric_definitions = [
1513+
MetricDefinition(
1514+
name="loss",
1515+
regex="Loss: (.*?)",
1516+
)
1517+
]
1518+
model_trainer = ModelTrainer(
1519+
...
1520+
).with_metric_definitions(metric_definitions)
1521+
Args:
1522+
metric_definitions (List[MetricDefinition]):
1523+
The metric definitions for the training job.
1524+
"""
1525+
self._metric_definitions = metric_definitions
1526+
1527+
return self

sagemaker-train/tests/unit/train/test_model_trainer.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
FileSystemDataSource,
6565
Channel,
6666
DataSource,
67+
MetricDefinition,
6768
)
6869
from sagemaker.train.distributed import Torchrun, SMP, MPI
6970
from sagemaker.train.sm_recipes.utils import _load_recipes_cfg
@@ -830,6 +831,7 @@ def mock_upload_data(path, bucket, key_prefix):
830831
training_input_mode=training_input_mode,
831832
training_image=training_image,
832833
algorithm_name=None,
834+
metric_definitions=None,
833835
container_entrypoint=DEFAULT_ENTRYPOINT,
834836
container_arguments=DEFAULT_ARGUMENTS,
835837
training_image_config=training_image_config,
@@ -1321,3 +1323,27 @@ def test_input_merge(mock_training_job, modules_session):
13211323
input_mode="File",
13221324
),
13231325
]
1326+
1327+
@patch("sagemaker.train.model_trainer.TrainingJob")
1328+
def test_metric_definitions(mock_training_job, modules_session):
1329+
image_uri = DEFAULT_IMAGE
1330+
role = DEFAULT_ROLE
1331+
metric_definitions = [
1332+
MetricDefinition(
1333+
name="loss",
1334+
regex="Loss: (.*?);",
1335+
)
1336+
]
1337+
model_trainer = ModelTrainer(
1338+
training_image=image_uri, sagemaker_session=modules_session, role=role
1339+
).with_metric_definitions(metric_definitions)
1340+
1341+
with patch("sagemaker.train.model_trainer.Session.upload_data") as mock_upload_data:
1342+
mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix"
1343+
model_trainer.train()
1344+
mock_training_job.create.assert_called_once()
1345+
1346+
assert (
1347+
mock_training_job.create.call_args.kwargs["algorithm_specification"].metric_definitions
1348+
== metric_definitions
1349+
)

0 commit comments

Comments
 (0)