@@ -201,38 +201,77 @@ def test_get_args_from_recipe_compute(
     )
     assert mock_gpu_args.call_count == 0
     assert mock_trainium_args.call_count == 0
-    assert args is None
-
-@pytest.mark.parametrize(
-    "test_case",
-    [
-        {
-            "model_type": "llama_v3",
-            "script": "llama_pretrain.py",
-            "model_base_name": "llama_v3",
-        },
-        {
-            "model_type": "mistral",
-            "script": "mistral_pretrain.py",
-            "model_base_name": "mistral",
-        },
-        {
-            "model_type": "deepseek_llamav3",
-            "script": "deepseek_pretrain.py",
-            "model_base_name": "deepseek",
-        },
-        {
-            "model_type": "deepseek_qwenv2",
-            "script": "deepseek_pretrain.py",
-            "model_base_name": "deepseek",
-        },
-    ],
+
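+# Each case pairs a recipe model_type with the base model name and pretrain
+# script the helper is expected to derive from it.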
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {
+            "model_type": "llama_v3",
+            "model_base_name": "llama",
+            "script": "llama_pretrain.py",
+        },
+        {
+            "model_type": "mistral",
+            "model_base_name": "mistral",
+            "script": "mistral_pretrain.py",
+        },
+        {
+            "model_type": "deepseek_llamav3",
+            "model_base_name": "deepseek",
+            "script": "deepseek_pretrain.py",
+        },
+        {
+            "model_type": "deepseek_qwenv2",
+            "model_base_name": "deepseek",
+            "script": "deepseek_pretrain.py",
+        },
+    ],
+)
+def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
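+    # The helper now takes only model_type and derives the script itself.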
+    model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
+        test_case["model_type"]
     )
-def test_get_trainining_recipe_gpu_model_name_and_script(test_case):
-    model_type = test_case["model_type"]
-    script = test_case["script"]
-    model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(
-        model_type, script
-    )
-    assert model_base_name == test_case["model_base_name"]
-    assert script == test_case["script"]
+    assert model_base_name == test_case["model_base_name"]
+    assert script == test_case["script"]
+
+
+def test_get_args_from_recipe_with_evaluation(temporary_recipe):
+    import os
+    from tempfile import NamedTemporaryFile
+    from sagemaker.train.configs import SourceCode
+
+    # Create a recipe with evaluation config
+    recipe_data = {
+        "trainer": {"num_nodes": 1},
+        "model": {"model_type": "llama_v3"},
+        "evaluation": {"task": "gen_qa"},
+        "processor": {"lambda_arn": "arn:aws:lambda:us-east-1:123456789012:function:MyFunc"},
+    }
+
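+    # Write the recipe to a temporary YAML file for _get_args_from_recipe to load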
+    with NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+        yaml.dump(recipe_data, f)
+        recipe_path = f.name
+
+    try:
+        compute = Compute(instance_type="ml.p4d.24xlarge", instance_count=1)
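+        # Stub GPU arg construction so no real recipe resolution happens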
+        with patch("sagemaker.train.sm_recipes.utils._configure_gpu_args") as mock_gpu:
+            mock_source = SourceCode()
+            mock_source.source_dir = "/tmp/test"
+            mock_gpu.return_value = {"source_code": mock_source, "hyperparameters": {}}
+            with patch("sagemaker.train.sm_recipes.utils.OmegaConf.save"):
+                args, _ = _get_args_from_recipe(
+                    training_recipe=recipe_path,
+                    compute=compute,
+                    region_name="us-west-2",
+                    recipe_overrides=None,
+                    requirements=None,
+                )
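+        # The processor's lambda_arn should be propagated into hyperparameters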
+        assert args["hyperparameters"]["lambda_arn"] == "arn:aws:lambda:us-east-1:123456789012:function:MyFunc"
+    finally:
+        os.unlink(recipe_path)