Skip to content

Commit 3b07b4a

Browse files
committed
Add ignore_patterns in ModelTrainer to ignore specific files/folders
For commit: 829030a
1 parent c3871ff commit 3b07b4a

File tree

4 files changed

+69
-9
lines changed

4 files changed

+69
-9
lines changed

sagemaker-core/src/sagemaker/core/modules/configs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,23 @@ class SourceCode(BaseConfig):
9999
command (Optional[str]):
100100
The command(s) to execute in the training job container. Example: "python my_script.py".
101101
If not specified, entry_script must be provided.
102+
ignore_patterns: (Optional[List[str]]) :
103+
The ignore patterns to ignore specific files/folders when uploading to S3. If not specified,
104+
default to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints'].
102105
"""
103106

104107
source_dir: Optional[str] = None
105108
requirements: Optional[str] = None
106109
entry_script: Optional[str] = None
107110
command: Optional[str] = None
111+
ignore_patterns: Optional[List[str]] = [
112+
".env",
113+
".git",
114+
"__pycache__",
115+
".DS_Store",
116+
".cache",
117+
".ipynb_checkpoints",
118+
]
108119

109120

110121
class Compute(shapes.ResourceConfig):

sagemaker-core/src/sagemaker/core/training/configs.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from __future__ import absolute_import
2323

24-
from typing import Optional, Union
24+
from typing import Optional, Union, List
2525
from pydantic import BaseModel, model_validator, ConfigDict
2626

2727
import sagemaker.core.shapes as shapes
@@ -106,13 +106,23 @@ class SourceCode(BaseConfig):
106106
command (Optional[StrPipeVar]):
107107
The command(s) to execute in the training job container. Example: "python my_script.py".
108108
If not specified, entry_script must be provided.
109+
ignore_patterns: (Optional[List[str]]) :
110+
The ignore patterns to ignore specific files/folders when uploading to S3. If not specified,
111+
default to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints'].
109112
"""
110113

111114
source_dir: Optional[StrPipeVar] = None
112115
requirements: Optional[StrPipeVar] = None
113116
entry_script: Optional[StrPipeVar] = None
114117
command: Optional[StrPipeVar] = None
115-
118+
ignore_patterns: Optional[List[str]] = [
119+
".env",
120+
".git",
121+
"__pycache__",
122+
".DS_Store",
123+
".cache",
124+
".ipynb_checkpoints",
125+
]
116126

117127
class OutputDataConfig(shapes.OutputDataConfig):
118128
"""OutputDataConfig.

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ class ModelTrainer(BaseModel):
125125
from sagemaker.train import ModelTrainer
126126
from sagemaker.train.configs import SourceCode, Compute, InputData
127127
128-
source_code = SourceCode(source_dir="source", entry_script="train.py")
128+
ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
129+
source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
129130
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
130131
model_trainer = ModelTrainer(
131132
training_image=training_image,
@@ -612,6 +613,7 @@ def train(
612613
channel_name=SM_CODE,
613614
data_source=self.source_code.source_dir,
614615
key_prefix=input_data_key_prefix,
616+
ignore_patterns=self.source_code.ignore_patterns,
615617
)
616618
final_input_data_config.append(source_code_channel)
617619

@@ -633,6 +635,7 @@ def train(
633635
channel_name=SM_DRIVERS,
634636
data_source=tmp_dir.name,
635637
key_prefix=input_data_key_prefix,
638+
ignore_patterns=self.source_code.ignore_patterns,
636639
)
637640
final_input_data_config.append(sm_drivers_channel)
638641

@@ -742,7 +745,11 @@ def train(
742745
local_container.train(wait)
743746

744747
def create_input_data_channel(
745-
self, channel_name: str, data_source: DataSourceType, key_prefix: Optional[str] = None
748+
self,
749+
channel_name: str,
750+
data_source: DataSourceType,
751+
key_prefix: Optional[str] = None,
752+
ignore_patterns: Optional[List[str]] = None,
746753
) -> Channel:
747754
"""Create an input data channel for the training job.
748755
@@ -758,6 +765,9 @@ def create_input_data_channel(
758765
759766
If specified, local data will be uploaded to:
760767
``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
768+
ignore_patterns: (Optional[List[str]]) :
769+
The ignore patterns to ignore specific files/folders when uploading to S3.
770+
If not specified, default to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints'].
761771
"""
762772
from sagemaker.core.helper.pipeline_variable import PipelineVariable
763773

@@ -807,11 +817,28 @@ def create_input_data_channel(
807817
)
808818
if self.sagemaker_session.default_bucket_prefix:
809819
key_prefix = f"{self.sagemaker_session.default_bucket_prefix}/{key_prefix}"
810-
s3_uri = self.sagemaker_session.upload_data(
811-
path=data_source,
812-
bucket=self.sagemaker_session.default_bucket(),
813-
key_prefix=key_prefix,
814-
)
820+
if ignore_patterns and _is_valid_path(data_source, path_type="Directory"):
821+
tmp_dir = TemporaryDirectory()
822+
copied_path = os.path.join(
823+
tmp_dir.name, os.path.basename(os.path.normpath(data_source))
824+
)
825+
shutil.copytree(
826+
data_source,
827+
copied_path,
828+
dirs_exist_ok=True,
829+
ignore=shutil.ignore_patterns(*ignore_patterns),
830+
)
831+
s3_uri = self.sagemaker_session.upload_data(
832+
path=copied_path,
833+
bucket=self.sagemaker_session.default_bucket(),
834+
key_prefix=key_prefix,
835+
)
836+
else:
837+
s3_uri = self.sagemaker_session.upload_data(
838+
path=data_source,
839+
bucket=self.sagemaker_session.default_bucket(),
840+
key_prefix=key_prefix,
841+
)
815842
channel = Channel(
816843
channel_name=channel_name,
817844
data_source=DataSource(

sagemaker-train/tests/unit/train/test_model_trainer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,17 @@ def model_trainer():
202202
},
203203
"should_throw": False,
204204
},
205+
{
206+
"init_params": {
207+
"training_image": DEFAULT_IMAGE,
208+
"source_code": SourceCode(
209+
source_dir=DEFAULT_SOURCE_DIR,
210+
command="python custom_script.py",
211+
ignore_patterns=["data"],
212+
),
213+
},
214+
"should_throw": False,
215+
},
205216
],
206217
ids=[
207218
"no_params",
@@ -213,6 +224,7 @@ def model_trainer():
213224
"supported_source_code_local_tar_file",
214225
"supported_source_code_s3_dir",
215226
"supported_source_code_s3_tar_file",
227+
"supported_source_code_ignore_patterns",
216228
],
217229
)
218230
def test_model_trainer_param_validation(test_case, modules_session):

0 commit comments

Comments
 (0)