
Commit 6feb6b6

Merge branch 'master' into master
2 parents 4d34e9b + cd95b19

File tree

22 files changed: +1125, -92 lines


sagemaker-train/src/sagemaker/ai_registry/dataset.py

Lines changed: 29 additions & 16 deletions
@@ -24,6 +24,7 @@
 
 import pandas as pd
 
+from sagemaker.ai_registry.dataset_format_detector import DatasetFormatDetector
 from sagemaker.ai_registry.air_hub import AIRHub
 from sagemaker.ai_registry.air_utils import _determine_new_version, _get_default_bucket
 from sagemaker.ai_registry.air_constants import (
@@ -179,6 +180,21 @@ def _validate_dataset_file(cls, file_path: str) -> None:
         max_size_mb = DATASET_MAX_FILE_SIZE_BYTES / (1024 * 1024)
         raise ValueError(f"File size {file_size_mb:.2f} MB exceeds maximum allowed size of {max_size_mb:.0f} MB")
 
+    @classmethod
+    def _validate_dataset_format(cls, file_path: str) -> None:
+        """Validate dataset format using DatasetFormatDetector.
+
+        Args:
+            file_path: Path to the dataset file (local path)
+
+        Raises:
+            ValueError: If dataset format cannot be detected
+        """
+        detector = DatasetFormatDetector()
+        format_name = detector.validate_dataset(file_path)
+        if format_name is False:
+            raise ValueError(f"Unable to detect format for {file_path}. Please provide a valid dataset file.")
+
     @classmethod
     @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DataSet.get")
     def get(cls, name: str, sagemaker_session=None) -> "DataSet":
@@ -257,28 +273,25 @@ def create(
             s3_prefix = s3_key  # Use full path including filename
             method = DataSetMethod.GENERATED
 
-            # Download and validate if customization technique is provided
-            if customization_technique:
-                with tempfile.NamedTemporaryFile(
-                    delete=False, suffix=os.path.splitext(s3_key)[1]
-                ) as tmp_file:
-                    local_path = tmp_file.name
-
-                try:
-                    AIRHub.download_from_s3(source, local_path)
-                    validate_dataset(local_path, customization_technique.value)
-                finally:
-                    if os.path.exists(local_path):
-                        os.remove(local_path)
+            # Download and validate format
+            with tempfile.NamedTemporaryFile(
+                delete=False, suffix=os.path.splitext(s3_key)[1]
+            ) as tmp_file:
+                local_path = tmp_file.name
+
+            try:
+                AIRHub.download_from_s3(source, local_path)
+                cls._validate_dataset_format(local_path)
+            finally:
+                if os.path.exists(local_path):
+                    os.remove(local_path)
         else:
             # Local file - upload to S3
             bucket_name = _get_default_bucket()
             s3_prefix = _get_default_s3_prefix(name)
             method = DataSetMethod.UPLOADED
 
-            if customization_technique:
-                validate_dataset(source, customization_technique.value)
-
+            cls._validate_dataset_format(source)
         AIRHub.upload_to_s3(bucket_name, s3_prefix, source)
 
         # Create hub content document
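With this change, format validation runs on every create path rather than only when a customization technique is supplied. As a minimal sketch of the new hook (hypothetical usage with a placeholder local JSONL path; `_validate_dataset_format` is the classmethod added above):

    from sagemaker.ai_registry.dataset import DataSet

    # Raises ValueError when DatasetFormatDetector cannot match the file
    # against any known format; returns None when a format is detected.
    DataSet._validate_dataset_format("train.jsonl")  # "train.jsonl" is a placeholder path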
sagemaker-train/src/sagemaker/ai_registry/dataset_format_detector.py

Lines changed: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+import json
+from typing import Dict, Any, Optional
+from pathlib import Path
+
+
+class DatasetFormatDetector:
+    """Utility class for detecting dataset formats."""
+
+    # Schema directory
+    SCHEMA_DIR = Path(__file__).parent / "schemas"
+
+    @staticmethod
+    def _load_schema(format_name: str) -> Dict[str, Any]:
+        """Load JSON schema for a format."""
+        schema_path = DatasetFormatDetector.SCHEMA_DIR / f"{format_name}.json"
+        if schema_path.exists():
+            with open(schema_path) as f:
+                return json.load(f)
+        return {}
+
+    @staticmethod
+    def validate_dataset(file_path: str) -> bool:
+        """
+        Validate if the dataset adheres to any known format.
+
+        Args:
+            file_path: Path to the JSONL file
+
+        Returns:
+            True if dataset is valid according to any known format, False otherwise
+        """
+        import jsonschema
+
+        # Schema-based formats
+        schema_formats = [
+            "dpo", "converse", "hf_preference", "hf_prompt_completion",
+            "verl", "openai_chat", "genqa"
+        ]
+
+        try:
+            with open(file_path, 'r') as f:
+                for line in f:
+                    line = line.strip()
+                    if line:
+                        data = json.loads(line)
+
+                        # Try schema validation first
+                        for format_name in schema_formats:
+                            schema = DatasetFormatDetector._load_schema(format_name)
+                            if schema:
+                                try:
+                                    jsonschema.validate(instance=data, schema=schema)
+                                    return True
+                                except jsonschema.exceptions.ValidationError:
+                                    continue
+
+                        # Check for RFT-style format (messages + additional fields)
+                        if DatasetFormatDetector._is_rft_format(data):
+                            return True
+                        break
+            return False
+        except (json.JSONDecodeError, FileNotFoundError, IOError):
+            return False
+
+    @staticmethod
+    def _is_rft_format(data: Dict[str, Any]) -> bool:
+        """Check if data matches RFT format pattern."""
+        if not isinstance(data, dict) or "messages" not in data:
+            return False
+
+        messages = data["messages"]
+        if not isinstance(messages, list) or not messages:
+            return False
+
+        # Check message structure
+        for msg in messages:
+            if not isinstance(msg, dict):
+                return False
+            if "role" not in msg or "content" not in msg:
+                return False
+            if not isinstance(msg["role"], str) or not isinstance(msg["content"], str):
+                return False
+
+        return True
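Note that `validate_dataset` inspects only the first non-empty line of the file (the loop breaks after one record), so detection is a cheap spot check rather than a full-file scan. A quick self-contained exercise of the RFT-style fallback (a hypothetical sketch; the schema files under `schemas/` are not shown in this diff, so the example leans on `_is_rft_format`):

    import json
    import tempfile

    from sagemaker.ai_registry.dataset_format_detector import DatasetFormatDetector

    # One RFT-style record: a "messages" list whose entries each carry
    # string "role" and "content" fields.
    record = {
        "messages": [
            {"role": "user", "content": "What is 2 + 2?"},
            {"role": "assistant", "content": "4"},
        ]
    }

    # Write a single-line JSONL file, then run detection on it.
    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
        f.write(json.dumps(record) + "\n")
        path = f.name

    print(DatasetFormatDetector.validate_dataset(path))  # True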
