Commit 8a4c359

Merge branch 'master' into master
2 parents bfcfae0 + 2c7c4b5 commit 8a4c359

26 files changed: +1139 −82

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+---
+name: Bug report
+about: File a report to help us reproduce and fix the problem
+title: ''
+labels: 'bug'
+assignees: ''
+
+---
+
+**PySDK Version**
+- [ ] PySDK V2 (2.x)
+- [ ] PySDK V3 (3.x)
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**To reproduce**
+A clear, step-by-step set of instructions to reproduce the bug.
+The provided code needs to be **complete** and **runnable**; if additional data is needed, please include it in the issue.
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+**Screenshots or logs**
+If applicable, add screenshots or logs to help explain your problem.
+
+**System information**
+A description of your system. Please provide:
+- **SageMaker Python SDK version**:
+- **Framework name (e.g. PyTorch) or algorithm (e.g. KMeans)**:
+- **Framework version**:
+- **Python version**:
+- **CPU or GPU**:
+- **Custom Docker image (Y/N)**:
+
+**Additional context**
+Add any other context about the problem here.

.github/ISSUE_TEMPLATE/config.yml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+blank_issues_enabled: false
+contact_links:
+  - name: Ask a question
+    url: https://github.com/aws/sagemaker-python-sdk/discussions
+    about: Use GitHub Discussions to ask and answer questions
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+---
+name: Documentation request
+about: Request improved documentation
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**What did you find confusing? Please describe.**
+A clear and concise description of what you found confusing. Ex. I tried to [...] but I didn't understand how to [...]
+
+**Describe how documentation can be improved**
+A clear and concise description of where documentation was lacking and how it can be improved.
+
+**Additional context**
+Add any other context or screenshots about the documentation request here.
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+---
+name: Feature request
+about: Suggest new functionality for this library
+title: ''
+labels: 'feature request'
+assignees: ''
+
+---
+
+**Describe the feature you'd like**
+A clear and concise description of the functionality you want.
+
+**How would this feature be used? Please describe.**
+A clear and concise description of the use case for this feature. Please provide an example, if possible.
+
+**Describe alternatives you've considered**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context**
+Add any other context or screenshots about the feature request here.

sagemaker-core/src/sagemaker/core/local/utils.py

Lines changed: 5 additions & 1 deletion
@@ -137,7 +137,11 @@ def get_child_process_ids(pid):
     Returns:
         (List[int]): Child process ids
     """
-    cmd = f"pgrep -P {pid}".split()
+    if not str(pid).isdigit():
+        raise ValueError("Invalid PID")
+
+    cmd = ["pgrep", "-P", str(pid)]
+
     process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
     output, err = process.communicate()
     if err:
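
As a standalone sketch of the hardened pattern (the helper name child_pids_sketch is hypothetical, not the SDK's API; it assumes a Unix host with pgrep on PATH): passing argv as a list means no shell ever parses the PID, and the isdigit() guard rejects values such as "123; rm -rf /" before pgrep runs.

import subprocess

def child_pids_sketch(pid):
    """Return the child PIDs of `pid`, rejecting non-numeric input up front."""
    if not str(pid).isdigit():
        raise ValueError("Invalid PID")
    result = subprocess.run(
        ["pgrep", "-P", str(pid)],  # argv list: the PID stays a literal argument
        capture_output=True,
        text=True,
    )
    # pgrep exits 1 and prints nothing when the process has no children
    return [int(token) for token in result.stdout.split()]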

sagemaker-core/tests/unit/local/test_local_utils.py

Lines changed: 8 additions & 5 deletions
@@ -103,21 +103,24 @@ def test_recursive_copy(copy_tree, m_os_path):
 @patch("sagemaker.core.local.utils.os")
 @patch("sagemaker.core.local.utils.get_child_process_ids")
 def test_kill_child_processes(m_get_child_process_ids, m_os):
-    m_get_child_process_ids.return_value = ["child_pids"]
-    kill_child_processes("pid")
-    m_os.kill.assert_called_with("child_pids", 15)
+    m_get_child_process_ids.return_value = ["345"]
+    kill_child_processes("123")
+    m_os.kill.assert_called_with("345", 15)
 
 
 @patch("sagemaker.core.local.utils.subprocess")
 def test_get_child_process_ids(m_subprocess):
-    cmd = "pgrep -P pid".split()
+    cmd = "pgrep -P 123".split()
     process_mock = Mock()
     attrs = {"communicate.return_value": (b"\n", False), "returncode": 0}
     process_mock.configure_mock(**attrs)
     m_subprocess.Popen.return_value = process_mock
-    get_child_process_ids("pid")
+    get_child_process_ids("123")
     m_subprocess.Popen.assert_called_with(cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE)
 
+def test_get_child_process_ids_exception():
+    with pytest.raises(ValueError, match="Invalid PID"):
+        get_child_process_ids("abc")
 
 @patch("sagemaker.core.local.utils.subprocess")
 def test_get_docker_host(m_subprocess):
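
The new exception test relies on pytest.raises(..., match=...), where match is applied as a regular-expression search against the exception message. A self-contained illustration of the idiom (the stub below mimics the validation guard and is not the SDK's code):

import pytest

def get_child_pids_stub(pid):
    # Stand-in for get_child_process_ids; only the error path matters here.
    if not str(pid).isdigit():
        raise ValueError("Invalid PID")
    return []

def test_rejects_non_numeric_pid():
    # match="Invalid PID" is searched (re.search) against str(ValueError(...)).
    with pytest.raises(ValueError, match="Invalid PID"):
        get_child_pids_stub("abc")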

sagemaker-train/src/sagemaker/ai_registry/dataset.py

Lines changed: 29 additions & 16 deletions
@@ -24,6 +24,7 @@
 
 import pandas as pd
 
+from sagemaker.ai_registry.dataset_format_detector import DatasetFormatDetector
 from sagemaker.ai_registry.air_hub import AIRHub
 from sagemaker.ai_registry.air_utils import _determine_new_version, _get_default_bucket
 from sagemaker.ai_registry.air_constants import (
@@ -179,6 +180,21 @@ def _validate_dataset_file(cls, file_path: str) -> None:
         max_size_mb = DATASET_MAX_FILE_SIZE_BYTES / (1024 * 1024)
         raise ValueError(f"File size {file_size_mb:.2f} MB exceeds maximum allowed size of {max_size_mb:.0f} MB")
 
+    @classmethod
+    def _validate_dataset_format(cls, file_path: str) -> None:
+        """Validate dataset format using DatasetFormatDetector.
+
+        Args:
+            file_path: Path to the dataset file (local path)
+
+        Raises:
+            ValueError: If dataset format cannot be detected
+        """
+        detector = DatasetFormatDetector()
+        format_name = detector.validate_dataset(file_path)
+        if format_name is False:
+            raise ValueError(f"Unable to detect format for {file_path}. Please provide a valid dataset file.")
+
     @classmethod
     @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DataSet.get")
     def get(cls, name: str, sagemaker_session=None) -> "DataSet":
@@ -257,28 +273,25 @@ def create(
             s3_prefix = s3_key  # Use full path including filename
             method = DataSetMethod.GENERATED
 
-            # Download and validate if customization technique is provided
-            if customization_technique:
-                with tempfile.NamedTemporaryFile(
-                    delete=False, suffix=os.path.splitext(s3_key)[1]
-                ) as tmp_file:
-                    local_path = tmp_file.name
-
-                try:
-                    AIRHub.download_from_s3(source, local_path)
-                    validate_dataset(local_path, customization_technique.value)
-                finally:
-                    if os.path.exists(local_path):
-                        os.remove(local_path)
+            # Download and validate format
+            with tempfile.NamedTemporaryFile(
+                delete=False, suffix=os.path.splitext(s3_key)[1]
+            ) as tmp_file:
+                local_path = tmp_file.name
+
+            try:
+                AIRHub.download_from_s3(source, local_path)
+                cls._validate_dataset_format(local_path)
+            finally:
+                if os.path.exists(local_path):
+                    os.remove(local_path)
         else:
             # Local file - upload to S3
             bucket_name = _get_default_bucket()
             s3_prefix = _get_default_s3_prefix(name)
             method = DataSetMethod.UPLOADED
 
-            if customization_technique:
-                validate_dataset(source, customization_technique.value)
-
+            cls._validate_dataset_format(source)
             AIRHub.upload_to_s3(bucket_name, s3_prefix, source)
 
         # Create hub content document
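
The net effect of these hunks: format validation now always runs before upload or registration, instead of only when a customization_technique was supplied. A self-contained sketch of the wrapper's contract, with a stubbed detector so it runs on its own (StubDetector and the module-level validate_dataset_format are illustrative, not the SDK's API):

class StubDetector:
    """Stand-in for DatasetFormatDetector; the real one is schema-based."""

    def validate_dataset(self, file_path: str) -> bool:
        # Toy rule for illustration only.
        return file_path.endswith(".jsonl")

def validate_dataset_format(file_path: str) -> None:
    """Mirror of _validate_dataset_format: a False result becomes a ValueError."""
    if StubDetector().validate_dataset(file_path) is False:
        raise ValueError(
            f"Unable to detect format for {file_path}. "
            "Please provide a valid dataset file."
        )

validate_dataset_format("train.jsonl")   # passes silently
try:
    validate_dataset_format("train.csv")  # raises
except ValueError as err:
    print(err)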
sagemaker-train/src/sagemaker/ai_registry/dataset_format_detector.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+import json
+from typing import Dict, Any, Optional
+from pathlib import Path
+
+
+class DatasetFormatDetector:
+    """Utility class for detecting dataset formats."""
+
+    # Schema directory
+    SCHEMA_DIR = Path(__file__).parent / "schemas"
+
+    @staticmethod
+    def _load_schema(format_name: str) -> Dict[str, Any]:
+        """Load JSON schema for a format."""
+        schema_path = DatasetFormatDetector.SCHEMA_DIR / f"{format_name}.json"
+        if schema_path.exists():
+            with open(schema_path) as f:
+                return json.load(f)
+        return {}
+
+    @staticmethod
+    def validate_dataset(file_path: str) -> bool:
+        """
+        Validate if the dataset adheres to any known format.
+
+        Args:
+            file_path: Path to the JSONL file
+
+        Returns:
+            True if dataset is valid according to any known format, False otherwise
+        """
+        import jsonschema
+
+        # Schema-based formats
+        schema_formats = [
+            "dpo", "converse", "hf_preference", "hf_prompt_completion",
+            "verl", "openai_chat", "genqa"
+        ]
+
+        try:
+            with open(file_path, 'r') as f:
+                for line in f:
+                    line = line.strip()
+                    if line:
+                        data = json.loads(line)
+
+                        # Try schema validation first
+                        for format_name in schema_formats:
+                            schema = DatasetFormatDetector._load_schema(format_name)
+                            if schema:
+                                try:
+                                    jsonschema.validate(instance=data, schema=schema)
+                                    return True
+                                except jsonschema.exceptions.ValidationError:
+                                    continue
+
+                        # Check for RFT-style format (messages + additional fields)
+                        if DatasetFormatDetector._is_rft_format(data):
+                            return True
+                        break
+            return False
+        except (json.JSONDecodeError, FileNotFoundError, IOError):
+            return False
+
+    @staticmethod
+    def _is_rft_format(data: Dict[str, Any]) -> bool:
+        """Check if data matches RFT format pattern."""
+        if not isinstance(data, dict) or "messages" not in data:
+            return False
+
+        messages = data["messages"]
+        if not isinstance(messages, list) or not messages:
+            return False
+
+        # Check message structure
+        for msg in messages:
+            if not isinstance(msg, dict):
+                return False
+            if "role" not in msg or "content" not in msg:
+                return False
+            if not isinstance(msg["role"], str) or not isinstance(msg["content"], str):
+                return False
+
+        return True
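
A usage sketch of the RFT fallback path, written against the class above (the record contents and the reward_fn key are illustrative; it assumes jsonschema is installed, since validate_dataset imports it, and that no schema files are present, so detection falls through to _is_rft_format):

import json
import tempfile

record = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "reward_fn": "exact_match",  # extra keys are allowed; only "messages" is checked
}

with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    f.write(json.dumps(record) + "\n")
    path = f.name

print(DatasetFormatDetector.validate_dataset(path))  # True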
