Skip to content

Commit d089d40

Browse files
authored
Fix failing ai_registry unit tests (#5426)
1 parent cd95b19 commit d089d40

File tree

1 file changed

+54
-14
lines changed

1 file changed

+54
-14
lines changed

sagemaker-train/tests/unit/ai_registry/test_dataset_domain_id.py

Lines changed: 54 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -11,27 +11,58 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Unit tests for domain-id tagging in DataSet."""
14+
import json
15+
import tempfile
16+
import os
1417
import pytest
1518
from unittest.mock import Mock, patch, MagicMock
1619
from sagemaker.ai_registry.dataset import DataSet
1720
from sagemaker.ai_registry.dataset_utils import CustomizationTechnique
1821

1922

23+
# One RLVR-format record (GSM8K style). The sample_dataset_file fixture
# serializes this dict into the temporary .jsonl file handed to DataSet.create.
# Key order is kept as-is because it determines the serialized JSON text.
SAMPLE_DATASET = {
    "data_source": "openai/gsm8k",
    "prompt": [
        {
            "content": (
                "Natalia sold clips to 48 of her friends in April, and then she "
                "sold half as many clips in May. How many clips did Natalia sell "
                "altogether in April and May? Let's think step by step and output "
                "the final answer after \"####\"."
            ),
            "role": "user",
        }
    ],
    "ability": "math",
    "reward_model": {"ground_truth": "72", "style": "rule"},
    "extra_info": {
        "answer": (
            "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
            "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
            "#### 72"
        ),
        "index": 0,
        "question": (
            "Natalia sold clips to 48 of her friends in April, and then she "
            "sold half as many clips in May. How many clips did Natalia sell "
            "altogether in April and May?"
        ),
        "split": "train",
    },
}
31+
32+
33+
@pytest.fixture
def sample_dataset_file():
    """Yield the path of a temporary ``.jsonl`` file containing SAMPLE_DATASET.

    The file is opened with ``delete=False`` so it still exists after the
    handle is closed; it is removed once the requesting test has finished.
    """
    handle = tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False)
    with handle:
        # A single JSON document, written without a trailing newline.
        handle.write(json.dumps(SAMPLE_DATASET))
        dataset_path = handle.name

    yield dataset_path

    # Teardown: the test may already have removed the file, so check first.
    if os.path.exists(dataset_path):
        os.remove(dataset_path)
45+
46+
2047
class TestDataSetDomainId:
2148
"""Test domain-id is added to SearchKeywords when available."""
2249

2350
@patch('sagemaker.core.helper.session_helper.Session')
2451
@patch('sagemaker.ai_registry.dataset._get_current_domain_id')
2552
@patch('sagemaker.ai_registry.dataset.AIRHub')
26-
@patch('sagemaker.ai_registry.dataset.validate_dataset')
53+
@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
54+
@patch('sagemaker.train.defaults.TrainDefaults.get_role')
2755
def test_domain_id_added_when_available(
28-
self, mock_validate, mock_air_hub, mock_get_domain_id, mock_session
56+
self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file
2957
):
3058
"""Test that domain-id is added to tags when available."""
3159
# Setup mocks
3260
mock_domain_id = "d-test123456"
3361
mock_get_domain_id.return_value = mock_domain_id
34-
mock_session.return_value = Mock()
62+
mock_session_instance = Mock()
63+
mock_session.return_value = mock_session_instance
64+
mock_get_session.return_value = mock_session_instance
65+
mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role"
3566

3667
# Mock AIRHub methods
3768
mock_air_hub.upload_to_s3 = Mock()
@@ -46,11 +77,11 @@ def test_domain_id_added_when_available(
4677
'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}'
4778
})
4879

49-
# Create dataset
80+
# Create dataset with real file
5081
with patch('sagemaker.ai_registry.dataset.DataSet.wait'):
5182
dataset = DataSet.create(
5283
name="test-dataset",
53-
source="test-data.jsonl",
84+
source=sample_dataset_file,
5485
customization_technique=CustomizationTechnique.SFT
5586
)
5687

@@ -67,14 +98,18 @@ def test_domain_id_added_when_available(
6798
@patch('sagemaker.core.helper.session_helper.Session')
6899
@patch('sagemaker.ai_registry.dataset._get_current_domain_id')
69100
@patch('sagemaker.ai_registry.dataset.AIRHub')
70-
@patch('sagemaker.ai_registry.dataset.validate_dataset')
101+
@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
102+
@patch('sagemaker.train.defaults.TrainDefaults.get_role')
71103
def test_domain_id_not_added_when_unavailable(
72-
self, mock_validate, mock_air_hub, mock_get_domain_id, mock_session
104+
self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file
73105
):
74106
"""Test that domain-id is not added when unavailable (non-Studio)."""
75107
# Setup mocks - domain_id returns None
76108
mock_get_domain_id.return_value = None
77-
mock_session.return_value = Mock()
109+
mock_session_instance = Mock()
110+
mock_session.return_value = mock_session_instance
111+
mock_get_session.return_value = mock_session_instance
112+
mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role"
78113

79114
# Mock AIRHub methods
80115
mock_air_hub.upload_to_s3 = Mock()
@@ -89,11 +124,11 @@ def test_domain_id_not_added_when_unavailable(
89124
'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}'
90125
})
91126

92-
# Create dataset
127+
# Create dataset with real file
93128
with patch('sagemaker.ai_registry.dataset.DataSet.wait'):
94129
dataset = DataSet.create(
95130
name="test-dataset",
96-
source="test-data.jsonl",
131+
source=sample_dataset_file,
97132
customization_technique=CustomizationTechnique.SFT
98133
)
99134

@@ -110,14 +145,19 @@ def test_domain_id_not_added_when_unavailable(
110145
@patch('sagemaker.core.helper.session_helper.Session')
111146
@patch('sagemaker.ai_registry.dataset._get_current_domain_id')
112147
@patch('sagemaker.ai_registry.dataset.AIRHub')
148+
@patch('sagemaker.train.defaults.TrainDefaults.get_sagemaker_session')
149+
@patch('sagemaker.train.defaults.TrainDefaults.get_role')
113150
def test_domain_id_added_without_customization_technique(
114-
self, mock_air_hub, mock_get_domain_id, mock_session
151+
self, mock_get_role, mock_get_session, mock_air_hub, mock_get_domain_id, mock_session, sample_dataset_file
115152
):
116153
"""Test that domain-id is added even without customization_technique."""
117154
# Setup mocks
118155
mock_domain_id = "d-test789"
119156
mock_get_domain_id.return_value = mock_domain_id
120-
mock_session.return_value = Mock()
157+
mock_session_instance = Mock()
158+
mock_session.return_value = mock_session_instance
159+
mock_get_session.return_value = mock_session_instance
160+
mock_get_role.return_value = "arn:aws:iam::123456789012:role/test-role"
121161

122162
# Mock AIRHub methods
123163
mock_air_hub.upload_to_s3 = Mock()
@@ -132,11 +172,11 @@ def test_domain_id_added_without_customization_technique(
132172
'HubContentDocument': '{"DatasetS3Bucket": "bucket", "DatasetS3Prefix": "prefix"}'
133173
})
134174

135-
# Create dataset WITHOUT customization_technique
175+
# Create dataset WITHOUT customization_technique using real file
136176
with patch('sagemaker.ai_registry.dataset.DataSet.wait'):
137177
dataset = DataSet.create(
138178
name="test-dataset",
139-
source="test-data.jsonl"
179+
source=sample_dataset_file
140180
# No customization_technique
141181
)
142182

0 commit comments

Comments
 (0)