Skip to content

Commit 471a246

Browse files
committed
add eval custom lambda arn to hyperparameter
For commit: aws/sagemaker-python-sdk-staging@bcd5348
1 parent e043098 commit 471a246

File tree

2 files changed

+78
-35
lines changed

2 files changed

+78
-35
lines changed

sagemaker-train/src/sagemaker/train/sm_recipes/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,10 +322,19 @@ def _get_args_from_recipe(
322322
args["source_code"].requirements = os.path.basename(requirements)
323323

324324
# Update args with compute and hyperparameters
325+
hyperparameters = {"config-path": ".", "config-name": "recipe.yaml"}
326+
327+
# Handle eval custom lambda configuration
328+
if recipe.get("evaluation", {}):
329+
processor = recipe.get("processor", {})
330+
lambda_arn = processor.get("lambda_arn", "")
331+
if lambda_arn:
332+
hyperparameters["lambda_arn"] = lambda_arn
333+
325334
args.update(
326335
{
327336
"compute": compute,
328-
"hyperparameters": {"config-path": ".", "config-name": "recipe.yaml"},
337+
"hyperparameters": hyperparameters,
329338
}
330339
)
331340

sagemaker-train/tests/unit/train/sm_recipes/test_utils.py

Lines changed: 68 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -201,38 +201,72 @@ def test_get_args_from_recipe_compute(
201201
)
202202
assert mock_gpu_args.call_count == 0
203203
assert mock_trainium_args.call_count == 0
204-
assert args is None
205-
206-
@pytest.mark.parametrize(
207-
"test_case",
208-
[
209-
{
210-
"model_type": "llama_v3",
211-
"script": "llama_pretrain.py",
212-
"model_base_name": "llama_v3",
213-
},
214-
{
215-
"model_type": "mistral",
216-
"script": "mistral_pretrain.py",
217-
"model_base_name": "mistral",
218-
},
219-
{
220-
"model_type": "deepseek_llamav3",
221-
"script": "deepseek_pretrain.py",
222-
"model_base_name": "deepseek",
223-
},
224-
{
225-
"model_type": "deepseek_qwenv2",
226-
"script": "deepseek_pretrain.py",
227-
"model_base_name": "deepseek",
228-
},
229-
],
204+
205+
@pytest.mark.parametrize(
206+
"test_case",
207+
[
208+
{
209+
"model_type": "llama_v3",
210+
"model_base_name": "llama",
211+
"script": "llama_pretrain.py",
212+
},
213+
{
214+
"model_type": "mistral",
215+
"model_base_name": "mistral",
216+
"script": "mistral_pretrain.py",
217+
},
218+
{
219+
"model_type": "deepseek_llamav3",
220+
"model_base_name": "deepseek",
221+
"script": "deepseek_pretrain.py",
222+
},
223+
{
224+
"model_type": "deepseek_qwenv2",
225+
"model_base_name": "deepseek",
226+
"script": "deepseek_pretrain.py",
227+
},
228+
],
229+
)
230+
def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
231+
model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
232+
test_case["model_type"]
230233
)
231-
def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
232-
model_type = test_case["model_type"]
233-
script = test_case["script"]
234-
model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
235-
model_type, script
236-
)
237-
assert model_base_name == test_case["model_base_name"]
238-
assert script == test_case["script"]
234+
assert model_base_name == test_case["model_base_name"]
235+
assert script == test_case["script"]
236+
237+
238+
def test_get_args_from_recipe_with_evaluation(temporary_recipe):
239+
import tempfile
240+
import os
241+
from sagemaker.train.configs import SourceCode
242+
243+
# Create a recipe with evaluation config
244+
recipe_data = {
245+
"trainer": {"num_nodes": 1},
246+
"model": {"model_type": "llama_v3"},
247+
"evaluation": {"task": "gen_qa"},
248+
"processor": {"lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyFunc"},
249+
}
250+
251+
with NamedTemporaryFile(suffix=".yaml", delete=False) as f:
252+
with open(f.name, "w") as file:
253+
yaml.dump(recipe_data, file)
254+
recipe_path = f.name
255+
256+
try:
257+
compute = Compute(instance_type="ml.p4d.24xlarge", instance_count=1)
258+
with patch("sagemaker.train.sm_recipes.utils._configure_gpu_args") as mock_gpu:
259+
mock_source = SourceCode()
260+
mock_source.source_dir = "/tmp/test"
261+
mock_gpu.return_value = {"source_code": mock_source, "hyperparameters": {}}
262+
with patch("sagemaker.train.sm_recipes.utils.OmegaConf.save"):
263+
args, _ = _get_args_from_recipe(
264+
training_recipe=recipe_path,
265+
compute=compute,
266+
region_name="us-west-2",
267+
recipe_overrides=None,
268+
requirements=None,
269+
)
270+
assert args["hyperparameters"]["lambda_arn"] == "arn:aws:lambda:us-east-1:123456789012:function:MyFunc"
271+
finally:
272+
os.unlink(recipe_path)

0 commit comments

Comments
 (0)