@classmethod
def _validate_dataset_format(cls, file_path: str) -> None:
    """Raise if the dataset file does not match any supported format.

    Delegates the actual check to ``DatasetFormatDetector.validate_dataset``
    and turns a negative answer into an exception.

    Args:
        file_path: Path to the dataset file (local path).

    Raises:
        ValueError: If the detector does not recognise the file's format.
    """
    # The detector answers with a bool. Test truthiness instead of identity
    # with ``False`` so that any falsy answer (e.g. ``None``) is rejected
    # rather than silently treated as "format detected".
    is_recognised = DatasetFormatDetector().validate_dataset(file_path)
    if not is_recognised:
        raise ValueError(
            f"Unable to detect format for {file_path}. "
            "Please provide a valid dataset file."
        )
class DatasetFormatDetector:
    """Utility class for detecting dataset formats.

    A dataset is a JSONL file. Detection looks at the first non-blank line
    only and reports whether that record matches one of the JSON schemas in
    ``SCHEMA_DIR`` or the looser "messages" (RFT-style) layout.
    """

    # Schema directory
    SCHEMA_DIR = Path(__file__).parent / "schemas"

    # Formats backed by ``<name>.json`` in SCHEMA_DIR, tried in this order.
    SCHEMA_FORMATS = (
        "dpo", "converse", "hf_preference", "hf_prompt_completion",
        "verl", "openai_chat", "genqa",
    )

    @staticmethod
    def _load_schema(format_name: str) -> Dict[str, Any]:
        """Load the JSON schema for ``format_name``.

        Returns:
            The parsed schema, or an empty dict when no schema file exists
            for that format.
        """
        schema_path = DatasetFormatDetector.SCHEMA_DIR / f"{format_name}.json"
        try:
            with open(schema_path, encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return {}

    @staticmethod
    def _read_first_record(file_path: str) -> Any:
        """Parse and return the first non-blank line of ``file_path``.

        Returns:
            The decoded JSON value, or ``None`` when the file contains no
            non-blank line.

        Raises:
            OSError: If the file cannot be opened or read.
            UnicodeDecodeError: If the file is not valid UTF-8.
            json.JSONDecodeError: If the first non-blank line is not JSON.
        """
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    return json.loads(line)
        return None

    @staticmethod
    def _matches_any_schema(data: Dict[str, Any]) -> bool:
        """Return True if ``data`` validates against any known schema."""
        # Imported here so files that can be rejected without schema
        # validation never need the third-party package.
        import jsonschema

        for format_name in DatasetFormatDetector.SCHEMA_FORMATS:
            schema = DatasetFormatDetector._load_schema(format_name)
            if not schema:
                # No schema file for this format; nothing to check against.
                continue
            try:
                jsonschema.validate(instance=data, schema=schema)
            except jsonschema.exceptions.ValidationError:
                continue
            return True
        return False

    @staticmethod
    def validate_dataset(file_path: str) -> bool:
        """
        Validate if the dataset adheres to any known format.

        Only the first non-blank line of the file is inspected; later lines
        are not read.

        Args:
            file_path: Path to the JSONL file

        Returns:
            True if the first record matches any known format. False if it
            matches none, or if the file is missing, unreadable, not UTF-8
            text, empty, or its first record is not a JSON object.
        """
        try:
            record = DatasetFormatDetector._read_first_record(file_path)
        except (json.JSONDecodeError, UnicodeDecodeError, OSError):
            return False

        # Every schema declares "type": "object" and the RFT check requires
        # a dict, so anything else (including an empty file) cannot match.
        if not isinstance(record, dict):
            return False

        # The two checks are OR-ed, so run the cheap structural one first.
        return (
            DatasetFormatDetector._is_rft_format(record)
            or DatasetFormatDetector._matches_any_schema(record)
        )

    @staticmethod
    def _is_rft_format(data: Dict[str, Any]) -> bool:
        """Check if data matches RFT format pattern.

        The pattern is a dict with a non-empty ``messages`` list in which
        every entry is a dict with string ``role`` and string ``content``.
        Other top-level keys are allowed.
        """
        if not isinstance(data, dict):
            return False

        messages = data.get("messages")
        if not isinstance(messages, list) or not messages:
            return False

        return all(
            isinstance(msg, dict)
            and isinstance(msg.get("role"), str)
            and isinstance(msg.get("content"), str)
            for msg in messages
        )
"properties": { + "schemaVersion": { + "type": "string" + }, + "system": { + "type": "array", + "items": { + "type": "object", + "required": ["text"], + "properties": { + "text": { + "type": "string" + } + }, + "additionalProperties": false + } + }, + "messages": { + "type": "array", + "minItems": 2, + "items": { + "type": "object", + "required": ["role", "content"], + "properties": { + "role": { + "type": "string", + "enum": ["user", "assistant"] + }, + "content": { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "properties": { + "text": { + "type": "string" + }, + "image": { + "type": "object", + "required": ["format", "source"], + "properties": { + "format": { + "type": "string", + "enum": ["jpeg", "png", "gif", "webp"] + }, + "source": { + "type": "object", + "required": ["s3Location"], + "properties": { + "s3Location": { + "type": "object", + "required": ["localPath"], + "properties": { + "localPath": { + "type": "string" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + }, + "video": { + "type": "object", + "required": ["format", "source"], + "properties": { + "format": { + "type": "string", + "enum": ["mov", "mkv", "mp4", "webm"] + }, + "source": { + "type": "object", + "required": ["s3Location"], + "properties": { + "s3Location": { + "type": "object", + "required": ["localPath"], + "properties": { + "localPath": { + "type": "string" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + }, + "document": { + "type": "object", + "required": ["format", "source"], + "properties": { + "format": { + "type": "string", + "enum": ["pdf"] + }, + "source": { + "type": "object", + "required": ["s3Location"], + "properties": { + "s3Location": { + "type": "object", + "required": ["localPath"], + "properties": { + "localPath": { + "type": "string" + } + }, + "additionalProperties": false + } + }, + 
"additionalProperties": false + } + }, + "additionalProperties": false + }, + "toolUse": { + "type": "object", + "required": ["toolUseId", "name", "input"], + "properties": { + "toolUseId": { + "type": "string" + }, + "name": { + "type": "string" + }, + "input": { + "type": "object" + } + }, + "additionalProperties": false + }, + "toolResult": { + "type": "object", + "required": ["toolUseId", "content"], + "properties": { + "toolUseId": { + "type": "string" + }, + "content": { + "type": "array", + "items": { + "type": "object", + "properties": { + "json": { + "type": "object" + }, + "text": { + "type": "string" + } + }, + "additionalProperties": true + } + } + }, + "additionalProperties": false + }, + "reasoningContent": { + "type": "object", + "required": ["reasoningText"], + "properties": { + "reasoningText": { + "type": "object", + "required": ["text"], + "properties": { + "text": { + "type": "string" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + } + }, + "additionalProperties": true + } + } + }, + "additionalProperties": false + } + }, + "toolConfig": { + "type": "object", + "required": ["tools"], + "properties": { + "tools": { + "type": "array", + "items": { + "type": "object", + "properties": { + "toolSpec": { + "type": "object", + "required": ["name"], + "properties": { + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "inputSchema": { + "type": "object", + "properties": { + "json": { + "type": "object" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + } + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false +} diff --git a/sagemaker-train/src/sagemaker/ai_registry/schemas/dpo.json b/sagemaker-train/src/sagemaker/ai_registry/schemas/dpo.json new file mode 100644 index 0000000000..c855735639 --- /dev/null +++ b/sagemaker-train/src/sagemaker/ai_registry/schemas/dpo.json @@ -0,0 +1,97 @@ +{ + 
"$schema": "http://json-schema.org/draft-07/schema#", + "title": "DPO (Direct Preference Optimization) Format", + "description": "DPO format with preference candidates based on ConverseDatasetSampleWithCandidates", + "type": "object", + "required": ["messages"], + "properties": { + "schemaVersion": { + "type": "string" + }, + "system": { + "type": "array", + "items": { + "type": "object", + "required": ["text"], + "properties": { + "text": { + "type": "string" + } + }, + "additionalProperties": false + } + }, + "messages": { + "type": "array", + "minItems": 2, + "items": { + "oneOf": [ + { + "type": "object", + "required": ["role", "content"], + "properties": { + "role": { + "type": "string", + "enum": ["user", "assistant"] + }, + "content": { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "properties": { + "text": { + "type": "string" + } + }, + "additionalProperties": true + } + } + }, + "additionalProperties": false + }, + { + "type": "object", + "required": ["role", "candidates"], + "properties": { + "role": { + "type": "string", + "enum": ["assistant"] + }, + "candidates": { + "type": "array", + "minItems": 2, + "items": { + "type": "object", + "required": ["content", "preferenceLabel"], + "properties": { + "content": { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "properties": { + "text": { + "type": "string" + } + }, + "additionalProperties": true + } + }, + "preferenceLabel": { + "type": "string", + "enum": ["preferred", "non-preferred"] + } + }, + "additionalProperties": false + } + } + }, + "additionalProperties": false + } + ] + } + } + }, + "additionalProperties": false +} diff --git a/sagemaker-train/src/sagemaker/ai_registry/schemas/genqa.json b/sagemaker-train/src/sagemaker/ai_registry/schemas/genqa.json new file mode 100644 index 0000000000..fd6d800737 --- /dev/null +++ b/sagemaker-train/src/sagemaker/ai_registry/schemas/genqa.json @@ -0,0 +1,37 @@ +{ + "$schema": 
"http://json-schema.org/draft-07/schema#", + "type": "object", + "description": "GenQA format for evaluation", + "properties": { + "query": { + "type": "string", + "description": "Required query to the model" + }, + "response": { + "type": "string", + "description": "Optional target response (required for evaluation, optional for inference-only)" + }, + "system": { + "type": "string", + "description": "Optional system prompt for the model" + }, + "metadata": { + "type": "object", + "description": "Optional metadata for labeling purposes" + }, + "images": { + "type": "array", + "items": { + "type": "object", + "properties": { + "data": { + "type": "string" + } + }, + "required": ["data"] + } + } + }, + "required": ["query"], + "additionalProperties": true +} diff --git a/sagemaker-train/src/sagemaker/ai_registry/schemas/hf_preference.json b/sagemaker-train/src/sagemaker/ai_registry/schemas/hf_preference.json new file mode 100644 index 0000000000..e10291a5e2 --- /dev/null +++ b/sagemaker-train/src/sagemaker/ai_registry/schemas/hf_preference.json @@ -0,0 +1,117 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "description": "HuggingFace Preference format for DPO/RLHF training", + "properties": { + "input": { + "oneOf": [ + { + "type": "string", + "description": "The input prompt or query as a string" + }, + { + "type": "array", + "description": "Array of conversation messages", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string" + }, + "content": { + "type": "string" + } + }, + "required": ["role", "content"] + } + } + ] + }, + "prompt": { + "oneOf": [ + { + "type": "string", + "description": "Alternative field for the input prompt as a string" + }, + { + "type": "array", + "description": "Array of conversation messages", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string" + }, + "content": { + "type": "string" + } + }, + "required": ["role", "content"] + } + } + ] + }, + 
"chosen": { + "oneOf": [ + { + "type": "string", + "description": "The preferred/chosen response as a string" + }, + { + "type": "array", + "description": "Array of assistant message(s) for chosen response", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string" + }, + "content": { + "type": "string" + } + }, + "required": ["role", "content"] + } + } + ] + }, + "rejected": { + "oneOf": [ + { + "type": "string", + "description": "The rejected/non-preferred response as a string" + }, + { + "type": "array", + "description": "Array of assistant message(s) for rejected response", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string" + }, + "content": { + "type": "string" + } + }, + "required": ["role", "content"] + } + } + ] + }, + "id": { + "type": "string", + "description": "Optional unique identifier" + }, + "attributes": { + "type": "object", + "description": "Optional metadata attributes" + }, + "difficulty": { + "type": "string", + "description": "Optional difficulty level" + } + }, + "required": ["chosen", "rejected"], + "additionalProperties": true +} diff --git a/sagemaker-train/src/sagemaker/ai_registry/schemas/hf_prompt_completion.json b/sagemaker-train/src/sagemaker/ai_registry/schemas/hf_prompt_completion.json new file mode 100644 index 0000000000..fc29944b84 --- /dev/null +++ b/sagemaker-train/src/sagemaker/ai_registry/schemas/hf_prompt_completion.json @@ -0,0 +1,69 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "description": "HuggingFace Prompt-Completion format for supervised fine-tuning", + "properties": { + "prompt": { + "oneOf": [ + { + "type": "string", + "description": "The input prompt as a string" + }, + { + "type": "array", + "description": "Array of conversation messages", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string" + }, + "content": { + "type": "string" + } + }, + "required": ["role", "content"] + } + } + ] + }, + 
"completion": { + "oneOf": [ + { + "type": "string", + "description": "The target completion as a string" + }, + { + "type": "array", + "description": "Array of assistant message(s)", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string" + }, + "content": { + "type": "string" + } + }, + "required": ["role", "content"] + } + } + ] + }, + "id": { + "type": "string", + "description": "Optional unique identifier" + }, + "attributes": { + "type": "object", + "description": "Optional metadata attributes" + }, + "difficulty": { + "type": "string", + "description": "Optional difficulty level" + } + }, + "required": ["prompt", "completion"], + "additionalProperties": true +} diff --git a/sagemaker-train/src/sagemaker/ai_registry/schemas/openai_chat.json b/sagemaker-train/src/sagemaker/ai_registry/schemas/openai_chat.json new file mode 100644 index 0000000000..e9c3599f4f --- /dev/null +++ b/sagemaker-train/src/sagemaker/ai_registry/schemas/openai_chat.json @@ -0,0 +1,25 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "OpenAI Chat Format", + "description": "OpenAI chat completion format with messages array", + "type": "object", + "required": ["messages"], + "properties": { + "messages": { + "type": "array", + "items": { + "type": "object", + "required": ["role", "content"], + "properties": { + "role": { + "type": "string", + "enum": ["system", "user", "assistant"] + }, + "content": { + "type": "string" + } + } + } + } + } +} diff --git a/sagemaker-train/src/sagemaker/ai_registry/schemas/verl.json b/sagemaker-train/src/sagemaker/ai_registry/schemas/verl.json new file mode 100644 index 0000000000..656a5bdfe1 --- /dev/null +++ b/sagemaker-train/src/sagemaker/ai_registry/schemas/verl.json @@ -0,0 +1,64 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "description": "VERL format for reinforcement learning training", + "properties": { + "prompt": { + "oneOf": [ + { + "type": "string", + 
"description": "The prompt/query as a string (legacy format)" + }, + { + "type": "array", + "description": "Array of conversation messages", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string" + }, + "content": { + "type": "string" + } + }, + "required": ["role", "content"] + } + } + ] + }, + "id": { + "type": "string", + "description": "Optional unique identifier" + }, + "attributes": { + "oneOf": [ + {"type": "object"}, + {"type": "string"} + ], + "description": "Optional metadata attributes" + }, + "difficulty": { + "type": "string", + "description": "Optional difficulty level" + }, + "data_source": { + "type": "string", + "description": "Optional source of the data" + }, + "ability": { + "type": "string", + "description": "Optional ability category" + }, + "reward_model": { + "type": "object", + "description": "Optional reward model configuration" + }, + "extra_info": { + "type": "object", + "description": "Optional additional information" + } + }, + "required": ["prompt"], + "additionalProperties": true +} diff --git a/sagemaker-train/tests/integ/ai_registry/test_dataset.py b/sagemaker-train/tests/integ/ai_registry/test_dataset.py index 5339b5cd42..f29c9e09e0 100644 --- a/sagemaker-train/tests/integ/ai_registry/test_dataset.py +++ b/sagemaker-train/tests/integ/ai_registry/test_dataset.py @@ -19,6 +19,7 @@ from sagemaker.ai_registry.dataset import DataSet from sagemaker.ai_registry.dataset_utils import CustomizationTechnique from sagemaker.ai_registry.air_constants import HubContentStatus +from unittest.mock import patch @pytest.mark.serial @@ -39,9 +40,9 @@ def test_create_dataset_from_local_file(self, unique_name, sample_jsonl_file, cl assert dataset.version is not None assert dataset.customization_technique == CustomizationTechnique.SFT - def test_create_dataset_from_s3_uri_sft(self, unique_name, test_bucket, cleanup_list): + def test_create_dataset_from_s3_oss_sft(self, unique_name, test_bucket, cleanup_list): """Test creating 
SFT dataset from S3 URI.""" - s3_uri = f"s3://{test_bucket}/test-sft-ds1.jsonl" + s3_uri = f"s3://{test_bucket}/test_datasets/OSS/oss_sft_train.jsonl" dataset = DataSet.create( name=unique_name, source=s3_uri, @@ -52,9 +53,22 @@ def test_create_dataset_from_s3_uri_sft(self, unique_name, test_bucket, cleanup_ assert dataset.name == unique_name assert dataset.customization_technique == CustomizationTechnique.SFT - def test_create_dataset_from_s3_uri_dpo(self, unique_name, test_bucket, cleanup_list): - """Test creating DPO dataset from S3 URI.""" - s3_uri = f"s3://{test_bucket}/preference_dataset_train_256.jsonl" + def test_create_dataset_from_s3_oss_rlvr(self, unique_name, test_bucket, cleanup_list): + """Test creating RLVR dataset from S3 URI.""" + s3_uri = f"s3://{test_bucket}/test_datasets/OSS/oss_rlvr_train.jsonl" + dataset = DataSet.create( + name=unique_name, + source=s3_uri, + customization_technique=CustomizationTechnique.RLVR, + wait=False + ) + cleanup_list.append(dataset) + assert dataset.name == unique_name + assert dataset.customization_technique == CustomizationTechnique.RLVR + + def test_create_dataset_from_s3_oss_dpo(self, unique_name, test_bucket, cleanup_list): + """Test creating RLVR dataset from S3 URI.""" + s3_uri = f"s3://{test_bucket}/test_datasets/OSS/oss_dpo_train.jsonl" dataset = DataSet.create( name=unique_name, source=s3_uri, @@ -65,6 +79,56 @@ def test_create_dataset_from_s3_uri_dpo(self, unique_name, test_bucket, cleanup_ assert dataset.name == unique_name assert dataset.customization_technique == CustomizationTechnique.DPO + def test_create_dataset_from_s3_nova_sft(self, unique_name, test_bucket, cleanup_list): + """Test creating RLVR dataset from S3 URI.""" + s3_uri = f"s3://{test_bucket}/test_datasets/Nova/nova_sft_train.jsonl" + dataset = DataSet.create( + name=unique_name, + source=s3_uri, + customization_technique=CustomizationTechnique.SFT, + wait=False + ) + cleanup_list.append(dataset) + assert dataset.name == unique_name + 
assert dataset.customization_technique == CustomizationTechnique.SFT + + def test_create_dataset_from_s3_nova_dpo(self, unique_name, test_bucket, cleanup_list): + """Test creating RLVR dataset from S3 URI.""" + s3_uri = f"s3://{test_bucket}/test_datasets/Nova/nova_dpo_train.jsonl" + dataset = DataSet.create( + name=unique_name, + source=s3_uri, + customization_technique=CustomizationTechnique.DPO, + wait=False + ) + cleanup_list.append(dataset) + assert dataset.name == unique_name + assert dataset.customization_technique == CustomizationTechnique.DPO + + def test_create_dataset_from_s3_nova_rft(self, unique_name, test_bucket, cleanup_list): + """Test creating RLVR dataset from S3 URI.""" + s3_uri = f"s3://{test_bucket}/test_datasets/Nova/nova_rft_train.jsonl" + dataset = DataSet.create( + name=unique_name, + source=s3_uri, + customization_technique=CustomizationTechnique.RLVR, + wait=False + ) + cleanup_list.append(dataset) + assert dataset.name == unique_name + assert dataset.customization_technique == CustomizationTechnique.RLVR + + def test_create_dataset_from_s3_nova_eval(self, unique_name, test_bucket, cleanup_list): + """Test creating RLVR dataset from S3 URI.""" + s3_uri = f"s3://{test_bucket}/test_datasets/Nova/nova_eval.jsonl" + dataset = DataSet.create( + name=unique_name, + source=s3_uri, + wait=False + ) + cleanup_list.append(dataset) + assert dataset.name == unique_name + def test_get_dataset(self, unique_name, sample_jsonl_file): """Test retrieving dataset by name.""" created = DataSet.create(name=unique_name, source=sample_jsonl_file, wait=False) @@ -122,6 +186,34 @@ def test_dataset_validation_invalid_extension(self, unique_name): with pytest.raises(ValueError, match="Unsupported file extension"): DataSet._validate_dataset_file("test.txt") + def test_create_dataset_with_invalid_format_s3(self, unique_name, test_bucket): + """Test creating dataset from S3 with invalid format fails.""" + # This would require an actual invalid file in S3, so we'll mock 
it + with patch('sagemaker.ai_registry.dataset.AIRHub.download_from_s3'), \ + patch('sagemaker.ai_registry.dataset.DataSet._validate_dataset_format', side_effect=ValueError("Invalid format")): + with pytest.raises(ValueError, match="Invalid format"): + DataSet.create( + name=unique_name, + source=f"s3://{test_bucket}/invalid_file.jsonl", + wait=False + ) + + def test_create_dataset_with_invalid_format_local(self, unique_name): + """Test creating dataset from local file with invalid format fails.""" + import tempfile + with tempfile.NamedTemporaryFile(suffix='.jsonl', mode='w', delete=False) as f: + f.write("invalid content") + f.flush() + try: + with pytest.raises(ValueError, match="Unable to detect format"): + DataSet.create( + name=unique_name, + source=f.name, + wait=False + ) + finally: + os.unlink(f.name) + def test_dataset_validation_large_file(self, unique_name): """Test dataset validation with oversized file.""" import tempfile @@ -156,3 +248,27 @@ def test_dataset_with_tags(self, unique_name, sample_jsonl_file, cleanup_list): cleanup_list.append(dataset) assert dataset.name == unique_name + def test_dataset_format_validation_success(self, unique_name, sample_jsonl_file): + """Test dataset format validation succeeds for valid files.""" + # Should not raise any exception for valid JSONL file + DataSet._validate_dataset_format(sample_jsonl_file) + + def test_dataset_format_validation_failure_invalid_format(self, unique_name): + """Test dataset format validation fails for invalid format.""" + import tempfile + with tempfile.NamedTemporaryFile(suffix='.jsonl', mode='w', delete=False) as f: + f.write("invalid json content") + f.flush() + with pytest.raises(ValueError, match="Unable to detect format"): + DataSet._validate_dataset_format(f.name) + os.unlink(f.name) + + def test_dataset_format_validation_failure_empty_file(self, unique_name): + """Test dataset format validation fails for empty files.""" + import tempfile + with 
tempfile.NamedTemporaryFile(suffix='.jsonl', delete=False) as f: + f.flush() # Create empty file + with pytest.raises(ValueError, match="Unable to detect format"): + DataSet._validate_dataset_format(f.name) + os.unlink(f.name) + diff --git a/sagemaker-train/tests/unit/ai_registry/test_dataset.py b/sagemaker-train/tests/unit/ai_registry/test_dataset.py index b2643928f0..16bb849e57 100644 --- a/sagemaker-train/tests/unit/ai_registry/test_dataset.py +++ b/sagemaker-train/tests/unit/ai_registry/test_dataset.py @@ -62,6 +62,28 @@ def test_validate_dataset_file_supported_extension(self): finally: os.unlink(temp_file) + @patch('sagemaker.ai_registry.dataset.DatasetFormatDetector') + def test_validate_dataset_format_success(self, mock_detector_class): + """Test dataset format validation succeeds when format is detected.""" + mock_detector = Mock() + mock_detector.validate_dataset.return_value = "jsonl" + mock_detector_class.return_value = mock_detector + + # Should not raise any exception + DataSet._validate_dataset_format("/path/to/file.jsonl") + mock_detector.validate_dataset.assert_called_once_with("/path/to/file.jsonl") + + @patch('sagemaker.ai_registry.dataset.DatasetFormatDetector') + def test_validate_dataset_format_failure(self, mock_detector_class): + """Test dataset format validation fails when format cannot be detected.""" + mock_detector = Mock() + mock_detector.validate_dataset.return_value = False + mock_detector_class.return_value = mock_detector + + with pytest.raises(ValueError, match="Unable to detect format for /path/to/file.jsonl"): + DataSet._validate_dataset_format("/path/to/file.jsonl") + mock_detector.validate_dataset.assert_called_once_with("/path/to/file.jsonl") + def test_validate_dataset_file_unsupported_extension(self): """Test validation fails for unsupported file extensions.""" # Test various unsupported extensions @@ -160,9 +182,9 @@ class TestDataSet: @patch('sagemaker.core.helper.session_helper.Session') 
@patch('sagemaker.train.common_utils.finetune_utils._get_current_domain_id') @patch('sagemaker.ai_registry.dataset.DataSet._validate_dataset_file') - @patch('sagemaker.ai_registry.dataset.validate_dataset') + @patch('sagemaker.ai_registry.dataset.DataSet._validate_dataset_format') @patch('sagemaker.ai_registry.dataset.AIRHub') - def test_create_with_s3_location(self, mock_air_hub, mock_validate, mock_validate_file, mock_get_domain_id, mock_session, mock_boto_client): + def test_create_with_s3_location(self, mock_air_hub, mock_validate_format, mock_validate_file, mock_get_domain_id, mock_session, mock_boto_client): mock_get_domain_id.return_value = None mock_session_instance = Mock() mock_session_instance.get_caller_identity_arn.return_value = "arn:aws:iam::123456789012:role/SageMakerRole" @@ -182,7 +204,7 @@ def test_create_with_s3_location(self, mock_air_hub, mock_validate, mock_validat "LastModifiedTime": "2024-01-01" } mock_air_hub.download_from_s3 = Mock() - mock_validate.return_value = None + mock_validate_format.return_value = None mock_validate_file.return_value = None def mock_exists(path): @@ -215,9 +237,9 @@ def mock_exists(path): @patch('sagemaker.core.helper.session_helper.Session') @patch('sagemaker.train.common_utils.finetune_utils._get_current_domain_id') @patch('sagemaker.ai_registry.dataset.DataSet._validate_dataset_file') - @patch('sagemaker.ai_registry.dataset.validate_dataset') + @patch('sagemaker.ai_registry.dataset.DataSet._validate_dataset_format') @patch('sagemaker.ai_registry.dataset.AIRHub') - def test_create_with_s3_location_preserves_full_path(self, mock_air_hub, mock_validate, mock_validate_file, mock_get_domain_id, mock_session, mock_boto_client): + def test_create_with_s3_location_preserves_full_path(self, mock_air_hub, mock_validate_format, mock_validate_file, mock_get_domain_id, mock_session, mock_boto_client): """Test that S3 path includes filename, not just directory.""" mock_get_domain_id.return_value = None mock_session_instance 
= Mock() @@ -238,7 +260,7 @@ def test_create_with_s3_location_preserves_full_path(self, mock_air_hub, mock_va "LastModifiedTime": "2024-01-01" } mock_air_hub.download_from_s3 = Mock() - mock_validate.return_value = None + mock_validate_format.return_value = None mock_validate_file.return_value = None def mock_exists(path): @@ -273,9 +295,9 @@ def mock_exists(path): assert document['DatasetS3Bucket'] == 'test-bucket' @patch('sagemaker.ai_registry.dataset.DataSet._validate_dataset_file') - @patch('sagemaker.ai_registry.dataset.validate_dataset') + @patch('sagemaker.ai_registry.dataset.DataSet._validate_dataset_format') @patch('sagemaker.ai_registry.dataset.AIRHub') - def test_create_with_local_file(self, mock_air_hub, mock_validate, mock_validate_file): + def test_create_with_local_file(self, mock_air_hub, mock_validate_format, mock_validate_file): mock_air_hub.upload_to_s3.return_value = "s3://bucket/path" mock_air_hub.import_hub_content.return_value = {"HubContentArn": "test-arn"} mock_air_hub.describe_hub_content.return_value = { @@ -284,7 +306,7 @@ def test_create_with_local_file(self, mock_air_hub, mock_validate, mock_validate "CreationTime": "2024-01-01", "LastModifiedTime": "2024-01-01" } - mock_validate.return_value = None + mock_validate_format.return_value = None mock_validate_file.return_value = None dataset = DataSet.create( @@ -297,7 +319,7 @@ def test_create_with_local_file(self, mock_air_hub, mock_validate, mock_validate assert dataset.name == "test-dataset" assert dataset.method == DataSetMethod.UPLOADED mock_air_hub.upload_to_s3.assert_called_once() - mock_validate.assert_called_once_with("/local/path/file.jsonl", "dpo") + mock_validate_format.assert_called_once_with("/local/path/file.jsonl") mock_validate_file.assert_called_once_with("/local/path/file.jsonl") @patch('sagemaker.ai_registry.dataset.AIRHub')