Skip to content

Commit 6be0fd3

Browse files
authored
Updating benchmark evaluation for subtasks and datasets (#5378)
1 parent 60574e5 commit 6be0fd3

File tree

4 files changed

+83
-111
lines changed

4 files changed

+83
-111
lines changed

sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ class _Benchmark(str, Enum):
3535
MATH = "math"
3636
STRONG_REJECT = "strong_reject"
3737
IFEVAL = "ifeval"
38-
GEN_QA = "gen_qa"
3938
MMMU = "mmmu"
4039
LLM_JUDGE = "llm_judge"
41-
INFERENCE_ONLY = "inference_only"
4240

4341

4442
# Internal benchmark configuration mapping - using plain dictionaries
@@ -138,14 +136,6 @@ class _Benchmark(str, Enum):
138136
"subtask_available": False,
139137
"subtasks": None
140138
},
141-
_Benchmark.GEN_QA: {
142-
"modality": "Multi-Modal (image)",
143-
"description": "Custom Dataset Evaluation – Lets you supply your own dataset for benchmarking, comparing model outputs to reference answers with metrics such as ROUGE and BLEU. gen_qa supports image inference for models which have multimodal support.",
144-
"metrics": ["all"],
145-
"strategy": "gen_qa",
146-
"subtask_available": False,
147-
"subtasks": None
148-
},
149139
_Benchmark.MMMU: {
150140
"modality": "Multi-Modal",
151141
"description": "Massive Multidiscipline Multimodal Understanding (MMMU) – College-level benchmark comprising multiple-choice and open-ended questions from 30 disciplines.",
@@ -171,14 +161,6 @@ class _Benchmark(str, Enum):
171161
"subtask_available": False,
172162
"subtasks": None
173163
},
174-
_Benchmark.INFERENCE_ONLY: {
175-
"modality": "Text",
176-
"description": "Lets you supply your own dataset to generate inference responses which can be used with the llm_judge task. No metrics are computed for this task.",
177-
"metrics": ["N/A"],
178-
"strategy": "--",
179-
"subtask_available": False,
180-
"subtasks": None
181-
},
182164
}
183165

184166

@@ -278,10 +260,6 @@ class BenchMarkEvaluator(BaseEvaluator):
278260
Optional. If not provided, the system will attempt to resolve it using the default
279261
MLflow app experience (checks domain match, account default, or creates a new app).
280262
Format: arn:aws:sagemaker:region:account:mlflow-tracking-server/name
281-
dataset (Union[str, Any]): Evaluation dataset. Required. Accepts:
282-
- S3 URI (str): e.g., 's3://bucket/path/dataset.jsonl'
283-
- Dataset ARN (str): e.g., 'arn:aws:sagemaker:...:hub-content/AIRegistry/DataSet/...'
284-
- DataSet object: sagemaker.ai_registry.dataset.DataSet instance (ARN inferred automatically)
285263
evaluate_base_model (bool): Whether to evaluate the base model in addition to the custom
286264
model. Set to False to skip base model evaluation and only evaluate the custom model.
287265
Defaults to True (evaluates both models).
@@ -309,7 +287,6 @@ class BenchMarkEvaluator(BaseEvaluator):
309287
benchmark=Benchmark.MMLU,
310288
subtasks=["abstract_algebra", "anatomy", "astronomy"],
311289
model="llama3-2-1b-instruct",
312-
dataset="s3://bucket/eval-data.jsonl",
313290
s3_output_path="s3://bucket/outputs/",
314291
mlflow_resource_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/my-server"
315292
)
@@ -327,16 +304,8 @@ class BenchMarkEvaluator(BaseEvaluator):
327304
_hyperparameters: Optional[Any] = None
328305

329306
# Template-required fields
330-
dataset: Union[str, Any]
331-
evaluate_base_model: bool = True
332-
333-
@validator('dataset', pre=True)
334-
def _resolve_dataset(cls, v):
335-
"""Resolve dataset to string (S3 URI or ARN) and validate format.
336-
337-
Uses BaseEvaluator's common validation logic to avoid code duplication.
338-
"""
339-
return BaseEvaluator._validate_and_resolve_dataset(v)
307+
evaluate_base_model: bool = False
308+
340309

341310
@validator('benchmark')
342311
def _validate_benchmark_model_compatibility(cls, v, values):
@@ -385,15 +354,21 @@ def _validate_subtasks(cls, v, values):
385354
f"Subtask list cannot be empty for benchmark '{benchmark.value}'. "
386355
f"Provide at least one subtask or use 'ALL'."
387356
)
388-
357+
if len(v) > 1 :
358+
raise ValueError(
359+
f"Currently only one subtask is supported for benchmark '{benchmark.value}'. "
360+
f"Provide only one subtask or use 'ALL'."
361+
)
362+
363+
# TODO : Should support list of subtasks.
389364
# Validate each subtask in the list
390365
for subtask in v:
391366
if not isinstance(subtask, str):
392367
raise ValueError(
393368
f"All subtasks in the list must be strings. "
394369
f"Found {type(subtask).__name__}: {subtask}"
395370
)
396-
371+
397372
# Validate against available subtasks if defined
398373
if config.get("subtasks") and subtask not in config["subtasks"]:
399374
raise ValueError(
@@ -527,23 +502,32 @@ def _resolve_subtask_for_evaluation(self, subtask: Optional[Union[str, List[str]
527502
"""
528503
# Use provided subtask or fall back to constructor subtasks
529504
eval_subtask = subtask if subtask is not None else self.subtasks
530-
505+
506+
if eval_subtask is None or eval_subtask.upper() == "ALL":
507+
#TODO : Check All Vs None subtask for evaluation
508+
return None
509+
531510
# Validate the subtask
532511
config = _BENCHMARK_CONFIG.get(self.benchmark)
533512
if config and config.get("subtask_available"):
534-
if isinstance(eval_subtask, list):
535-
for st in eval_subtask:
536-
if config.get("subtasks") and st not in config["subtasks"] and st.upper() != "ALL":
537-
raise ValueError(
538-
f"Invalid subtask '{st}' for benchmark '{self.benchmark.value}'. "
539-
f"Available subtasks: {', '.join(config['subtasks'])}"
540-
)
541-
elif isinstance(eval_subtask, str):
513+
if isinstance(eval_subtask, str):
542514
if eval_subtask.upper() != "ALL" and config.get("subtasks") and eval_subtask not in config["subtasks"]:
543515
raise ValueError(
544516
f"Invalid subtask '{eval_subtask}' for benchmark '{self.benchmark.value}'. "
545517
f"Available subtasks: {', '.join(config['subtasks'])}"
546518
)
519+
elif isinstance(eval_subtask, list):
520+
if len(eval_subtask) == 0:
521+
raise ValueError(
522+
f"Subtask list cannot be empty for benchmark '{self.benchmark.value}'. "
523+
f"Provide at least one subtask or use 'ALL'."
524+
)
525+
if len(eval_subtask) > 1:
526+
raise ValueError(
527+
f"Currently only one subtask is supported for benchmark '{self.benchmark.value}'. "
528+
f"Provide only one subtask or use 'ALL'."
529+
)
530+
547531

548532
return eval_subtask
549533

@@ -573,10 +557,12 @@ def _get_benchmark_template_additions(self, eval_subtask: Optional[Union[str, Li
573557
'task': self.benchmark.value,
574558
'strategy': config["strategy"],
575559
metric_key: config["metrics"][0] if config.get("metrics") else 'accuracy',
576-
'subtask': eval_subtask if isinstance(eval_subtask, str) else ','.join(eval_subtask) if eval_subtask else '',
577560
'evaluate_base_model': self.evaluate_base_model,
578561
}
579562

563+
if isinstance(eval_subtask, str):
564+
benchmark_context['subtask'] = eval_subtask
565+
580566
# Add all configured hyperparameters
581567
for key in configured_params.keys():
582568
benchmark_context[key] = configured_params[key]
@@ -604,7 +590,6 @@ def evaluate(self, subtask: Optional[Union[str, List[str]]] = None) -> Evaluatio
604590
benchmark=Benchmark.MMLU,
605591
subtasks="ALL",
606592
model="llama3-2-1b-instruct",
607-
dataset="s3://bucket/data.jsonl",
608593
s3_output_path="s3://bucket/outputs/"
609594
)
610595
@@ -645,9 +630,7 @@ def evaluate(self, subtask: Optional[Union[str, List[str]]] = None) -> Evaluatio
645630
model_package_group_arn=model_package_group_arn,
646631
resolved_model_artifact_arn=artifacts['resolved_model_artifact_arn']
647632
)
648-
649-
# Add dataset URI
650-
template_context['dataset_uri'] = self.dataset
633+
651634

652635
# Add benchmark-specific template additions
653636
benchmark_additions = self._get_benchmark_template_additions(eval_subtask, config)

sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@
129129
{% if kms_key_id %},
130130
"KmsKeyId": "{{ kms_key_id }}"
131131
{% endif %}
132-
},
132+
}{% if dataset_uri %},
133133
"InputDataConfig": [
134134
{
135135
"ChannelName": "train",
@@ -144,7 +144,7 @@
144144
}
145145
}{% endif %}
146146
}
147-
]{% if vpc_config %},
147+
]{% endif %}{% if vpc_config %},
148148
"VpcConfig": {
149149
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
150150
"Subnets": {{ vpc_subnets | tojson }}
@@ -191,7 +191,7 @@
191191
{% if kms_key_id %},
192192
"KmsKeyId": "{{ kms_key_id }}"
193193
{% endif %}
194-
},
194+
}{% if dataset_uri %},
195195
"InputDataConfig": [
196196
{
197197
"ChannelName": "train",
@@ -206,7 +206,7 @@
206206
}
207207
}{% endif %}
208208
}
209-
]{% if vpc_config %},
209+
]{% endif %}{% if vpc_config %},
210210
"VpcConfig": {
211211
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
212212
"Subnets": {{ vpc_subnets | tojson }}
@@ -358,7 +358,7 @@
358358
{% if kms_key_id %},
359359
"KmsKeyId": "{{ kms_key_id }}"
360360
{% endif %}
361-
},
361+
}{% if dataset_uri %},
362362
"InputDataConfig": [
363363
{
364364
"ChannelName": "train",
@@ -373,7 +373,7 @@
373373
}
374374
}{% endif %}
375375
}
376-
]{% if vpc_config %},
376+
]{% endif %}{% if vpc_config %},
377377
"VpcConfig": {
378378
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
379379
"Subnets": {{ vpc_subnets | tojson }}
@@ -500,7 +500,7 @@
500500
{% if kms_key_id %},
501501
"KmsKeyId": "{{ kms_key_id }}"
502502
{% endif %}
503-
},
503+
}{% if dataset_uri %},
504504
"InputDataConfig": [
505505
{
506506
"ChannelName": "train",
@@ -515,7 +515,7 @@
515515
}
516516
}{% endif %}
517517
}
518-
]{% if vpc_config %},
518+
]{% endif %}{% if vpc_config %},
519519
"VpcConfig": {
520520
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
521521
"Subnets": {{ vpc_subnets | tojson }}
@@ -650,7 +650,7 @@
650650
{% if kms_key_id %},
651651
"KmsKeyId": "{{ kms_key_id }}"
652652
{% endif %}
653-
},
653+
}{% if dataset_uri %},
654654
"InputDataConfig": [
655655
{
656656
"ChannelName": "train",
@@ -665,7 +665,7 @@
665665
}
666666
}{% endif %}
667667
}
668-
]{% if vpc_config %},
668+
]{% endif %}{% if vpc_config %},
669669
"VpcConfig": {
670670
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
671671
"Subnets": {{ vpc_subnets | tojson }}
@@ -713,7 +713,7 @@
713713
{% if kms_key_id %},
714714
"KmsKeyId": "{{ kms_key_id }}"
715715
{% endif %}
716-
},
716+
}{% if dataset_uri %},
717717
"InputDataConfig": [
718718
{
719719
"ChannelName": "train",
@@ -728,7 +728,7 @@
728728
}
729729
}{% endif %}
730730
}
731-
]{% if vpc_config %},
731+
]{% endif %}{% if vpc_config %},
732732
"VpcConfig": {
733733
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
734734
"Subnets": {{ vpc_subnets | tojson }}
@@ -892,7 +892,7 @@
892892
{% if kms_key_id %},
893893
"KmsKeyId": "{{ kms_key_id }}"
894894
{% endif %}
895-
},
895+
}{% if dataset_uri %},
896896
"InputDataConfig": [
897897
{
898898
"ChannelName": "train",
@@ -907,7 +907,7 @@
907907
}
908908
}{% endif %}
909909
}
910-
]{% if vpc_config %},
910+
]{% endif %}{% if vpc_config %},
911911
"VpcConfig": {
912912
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
913913
"Subnets": {{ vpc_subnets | tojson }}
@@ -1032,7 +1032,7 @@
10321032
"ModelPackageConfig": {
10331033
"ModelPackageGroupArn": "{{ model_package_group_arn }}",
10341034
"SourceModelPackageArn": "{{ source_model_package_arn }}"
1035-
},
1035+
}{% if dataset_uri %},
10361036
"InputDataConfig": [
10371037
{
10381038
"ChannelName": "train",
@@ -1047,7 +1047,7 @@
10471047
}
10481048
}{% endif %}
10491049
}
1050-
]{% if vpc_config %},
1050+
]{% endif %}{% if vpc_config %},
10511051
"VpcConfig": {
10521052
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
10531053
"Subnets": {{ vpc_subnets | tojson }}
@@ -1086,7 +1086,7 @@
10861086
"ModelPackageConfig": {
10871087
"ModelPackageGroupArn": "{{ model_package_group_arn }}",
10881088
"SourceModelPackageArn": "{{ source_model_package_arn }}"
1089-
},
1089+
}{% if dataset_uri %},
10901090
"InputDataConfig": [
10911091
{
10921092
"ChannelName": "train",
@@ -1101,7 +1101,7 @@
11011101
}
11021102
}{% endif %}
11031103
}
1104-
]{% if vpc_config %},
1104+
]{% endif %}{% if vpc_config %},
11051105
"VpcConfig": {
11061106
"SecurityGroupIds": {{ vpc_security_group_ids | tojson }},
11071107
"Subnets": {{ vpc_subnets | tojson }}

0 commit comments

Comments (0)