
Commit 2cdb2d4

fixing unit tests and adding integration test
1 parent 1d6c559 commit 2cdb2d4

File tree

16 files changed: +1492 -0 lines changed

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
# flake8: noqa
import argparse
import numpy as np
import os
import sys
import logging
import json
import shutil
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_model_def import get_model

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))
current_dir = os.path.dirname(os.path.abspath(__file__))


def get_train_data(train_dir):
    """
    Get the training data and convert to tensors
    """

    x_train = np.load(os.path.join(train_dir, "x_train.npy"))
    y_train = np.load(os.path.join(train_dir, "y_train.npy"))
    logger.info(f"x train: {x_train.shape}, y train: {y_train.shape}")

    return torch.from_numpy(x_train), torch.from_numpy(y_train)


def get_test_data(test_dir):
    """
    Get the testing data and convert to tensors
    """

    x_test = np.load(os.path.join(test_dir, "x_test.npy"))
    y_test = np.load(os.path.join(test_dir, "y_test.npy"))
    logger.info(f"x test: {x_test.shape}, y test: {y_test.shape}")

    return torch.from_numpy(x_test), torch.from_numpy(y_test)


def model_fn(model_dir):
    """
    Load the model for inference
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model()
    # map_location keeps loading robust when the weights were saved on a GPU instance
    model.load_state_dict(torch.load(model_dir + "/model.pth", map_location=device))
    model.eval()
    return model.to(device)


def input_fn(request_body, request_content_type):
    """
    Deserialize and prepare the prediction input
    """

    if request_content_type == "application/json":
        request = json.loads(request_body)
        train_inputs = torch.tensor(request)
        return train_inputs
    # Fail loudly instead of silently returning None for unsupported content types
    raise ValueError(f"Unsupported content type: {request_content_type}")


def predict_fn(input_data, model):
    """
    Apply model to the incoming request
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    with torch.no_grad():
        # Move the input to the model's device and bring the result back to CPU before numpy()
        return model(input_data.float().to(device)).cpu().numpy()[0]


def parse_args():
    """
    Parse the command line arguments
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-dir",
        type=str,
        default=os.environ.get("SM_MODEL_DIR", os.path.join(current_dir, "data/model")),
        help="Directory to save the model",
    )
    parser.add_argument(
        "--train-dir",
        type=str,
        default=os.environ.get("SM_CHANNEL_TRAIN", os.path.join(current_dir, "data/train")),
        help="Directory containing training data",
    )
    parser.add_argument(
        "--test-dir",
        type=str,
        default=os.environ.get("SM_CHANNEL_TEST", os.path.join(current_dir, "data/test")),
        help="Directory containing testing data",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="Batch size for training",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="Number of epochs for training",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=0.1,
        help="Learning rate for training",
    )
    return parser.parse_args()


def train():
    """
    Train the PyTorch model
    """
    args = parse_args()
    # Resolve the device here so train() also works when imported rather than run as a script
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Directories: train, test and model
    train_dir = args.train_dir
    test_dir = args.test_dir
    model_dir = args.model_dir

    # Load the training and testing data
    x_train, y_train = get_train_data(train_dir)
    x_test, y_test = get_test_data(test_dir)
    train_ds = TensorDataset(x_train, y_train)

    # Training parameters - used to configure the training loop
    batch_size = args.batch_size
    epochs = args.epochs
    learning_rate = args.learning_rate
    logger.info(
        "batch_size = {}, epochs = {}, learning rate = {}".format(batch_size, epochs, learning_rate)
    )

    train_dl = DataLoader(train_ds, batch_size, shuffle=True)

    # Define the model, loss function and optimizer
    model = get_model()
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # Train the model
    for epoch in range(epochs):
        for x_train_batch, y_train_batch in train_dl:
            # Keep the batch on the same device as the model
            x_train_batch = x_train_batch.to(device)
            y_train_batch = y_train_batch.to(device)
            y = model(x_train_batch.float())
            loss = criterion(y.flatten(), y_train_batch.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        logger.info(f"epoch: {epoch + 1} -> loss: {loss}")

    # Test the model
    with torch.no_grad():
        y = model(x_test.float().to(device)).flatten().cpu()
        mse = ((y - y_test) ** 2).sum() / y_test.shape[0]
        print("\nTest MSE:", mse.numpy())

    # Save the model
    os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_dir + "/model.pth")
    inference_code_path = model_dir + "/code/"

    if not os.path.exists(inference_code_path):
        os.mkdir(inference_code_path)
        logger.info("Created a folder at {}!".format(inference_code_path))

    shutil.copy("custom_script.py", inference_code_path)
    shutil.copy("pytorch_model_def.py", inference_code_path)
    logger.info("Saving model files to {}".format(inference_code_path))


if __name__ == "__main__":
    print("Running the training job ...\n")

    train()
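
For reference, model_fn, input_fn, and predict_fn above follow the SageMaker PyTorch inference handler contract: load the model, deserialize the JSON request, run the forward pass. A minimal local smoke test might look like the sketch below; it assumes the training script is saved as custom_script.py (as the shutil.copy call suggests) and uses a throwaway temporary directory and a random input purely for illustration.

# Hypothetical local smoke test for the inference handlers above; not part of the commit.
import json
import tempfile

import numpy as np
import torch

import custom_script  # assumes the script above is saved as custom_script.py
from pytorch_model_def import get_model

with tempfile.TemporaryDirectory() as model_dir:
    # Save an (untrained) model the same way train() does, then reload it via model_fn
    torch.save(get_model().state_dict(), model_dir + "/model.pth")
    model = custom_script.model_fn(model_dir)

    # One fake JSON request with a single 8-feature row, matching NeuralNet's input size
    body = json.dumps([np.random.rand(8).tolist()])
    inputs = custom_script.input_fn(body, "application/json")
    print("prediction:", custom_script.predict_fn(inputs, model))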
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
# flake8: noqa
import torch
import torch.nn as nn


class NeuralNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 6)
        self.fc3 = nn.Linear(6, 1)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)
        return x


def get_model():
    model = NeuralNet()
    return model
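
As a quick sanity check (illustrative only, not part of the commit), the network maps an 8-feature input to a single regression output:

import torch
from pytorch_model_def import get_model

model = get_model()
out = model(torch.randn(4, 8))  # a batch of 4 rows with 8 features each
print(out.shape)  # torch.Size([4, 1])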
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
numpy
-f https://download.pytorch.org/whl/torch_stable.html
torch==2.7.0

sagemaker-train/tests/integ/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -12,3 +12,7 @@
# language governing permissions and limitations under the License.
"""This module contains the Integ Tests for SageMaker PySDK Training."""
from __future__ import absolute_import

import os

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""AWS Batch integration tests"""
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import time
16+
17+
18+
class BatchTestResourceManager:
19+
20+
def __init__(
21+
self,
22+
batch_client,
23+
queue_name="pysdk-test-queue",
24+
service_env_name="pysdk-test-queue-service-environment",
25+
):
26+
self.batch_client = batch_client
27+
self.queue_name = queue_name
28+
self.service_environment_name = service_env_name
29+
30+
def _create_or_get_service_environment(self, service_environment_name):
31+
print(f"Creating service environment: {service_environment_name}")
32+
try:
33+
response = self.batch_client.create_service_environment(
34+
serviceEnvironmentName=service_environment_name,
35+
serviceEnvironmentType="SAGEMAKER_TRAINING",
36+
capacityLimits=[{"maxCapacity": 10, "capacityUnit": "NUM_INSTANCES"}],
37+
)
38+
print(f"Service environment {service_environment_name} created successfully.")
39+
return response
40+
except Exception as e:
41+
if "Object already exists" in str(e):
42+
print("Resource already exists. Fetching existing resource.")
43+
response = self.batch_client.describe_service_environments(
44+
serviceEnvironments=[service_environment_name]
45+
)
46+
return response["serviceEnvironments"][0]
47+
else:
48+
print(f"Error creating service environment: {e}")
49+
raise
50+
51+
def _create_or_get_queue(self, queue_name, service_environment_arn):
52+
53+
print(f"Creating job queue: {queue_name}")
54+
try:
55+
response = self.batch_client.create_job_queue(
56+
jobQueueName=queue_name,
57+
priority=1,
58+
computeEnvironmentOrder=[],
59+
serviceEnvironmentOrder=[
60+
{
61+
"order": 1,
62+
"serviceEnvironment": service_environment_arn,
63+
},
64+
],
65+
jobQueueType="SAGEMAKER_TRAINING",
66+
)
67+
print(f"Job queue {queue_name} created successfully.")
68+
return response
69+
except Exception as e:
70+
if "Object already exists" in str(e):
71+
print("Resource already exists. Fetching existing resource.")
72+
response = self.batch_client.describe_job_queues(jobQueues=[queue_name])
73+
return response["jobQueues"][0]
74+
else:
75+
print(f"Error creating job queue: {e}")
76+
raise
77+
78+
def _update_queue_state(self, queue_name, state):
79+
try:
80+
print(f"Updating queue {queue_name} to state {state}")
81+
response = self.batch_client.update_job_queue(jobQueue=queue_name, state=state)
82+
return response
83+
except Exception as e:
84+
print(f"Error updating queue: {e}")
85+
86+
def _update_service_environment_state(self, service_environment_name, state):
87+
print(f"Updating service environment {service_environment_name} to state {state}")
88+
try:
89+
response = self.batch_client.update_service_environment(
90+
serviceEnvironment=service_environment_name, state=state
91+
)
92+
return response
93+
except Exception as e:
94+
print(f"Error updating service environment: {e}")
95+
96+
def _wait_for_queue_state(self, queue_name, state):
97+
print(f"Waiting for queue {queue_name} to be {state}...")
98+
while True:
99+
response = self.batch_client.describe_job_queues(jobQueues=[queue_name])
100+
print(f"Current state: {response}")
101+
if response["jobQueues"][0]["state"] == state:
102+
break
103+
time.sleep(5)
104+
print(f"Queue {queue_name} is now {state}.")
105+
106+
def _wait_for_service_environment_state(self, service_environment_name, state):
107+
print(f"Waiting for service environment {service_environment_name} to be {state}...")
108+
while True:
109+
response = self.batch_client.describe_service_environments(
110+
serviceEnvironments=[service_environment_name]
111+
)
112+
print(f"Current state: {response}")
113+
if response["serviceEnvironments"][0]["state"] == state:
114+
break
115+
time.sleep(5)
116+
print(f"Service environment {service_environment_name} is now {state}.")
117+
118+
def get_or_create_resources(self, queue_name=None, service_environment_name=None):
119+
queue_name = queue_name or self.queue_name
120+
service_environment_name = service_environment_name or self.service_environment_name
121+
122+
service_environment = self._create_or_get_service_environment(service_environment_name)
123+
if service_environment.get("state") != "ENABLED":
124+
self._update_service_environment_state(service_environment_name, "ENABLED")
125+
self._wait_for_service_environment_state(service_environment_name, "ENABLED")
126+
time.sleep(10)
127+
128+
queue = self._create_or_get_queue(queue_name, service_environment["serviceEnvironmentArn"])
129+
if queue.get("state") != "ENABLED":
130+
self._update_queue_state(queue_name, "ENABLED")
131+
self._wait_for_queue_state(queue_name, "ENABLED")
132+
time.sleep(10)
133+
return queue, service_environment
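
A rough sketch of how an integration test might drive BatchTestResourceManager; the boto3 client construction and region are assumptions for illustration, not part of this commit.

# Hypothetical usage from an integ test; region and client setup are assumed.
import boto3

batch_client = boto3.client("batch", region_name="us-west-2")
resource_manager = BatchTestResourceManager(batch_client)

# Ensures the SAGEMAKER_TRAINING service environment and job queue exist and are ENABLED
queue, service_environment = resource_manager.get_or_create_resources()
print(queue.get("jobQueueArn"), service_environment.get("serviceEnvironmentArn"))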
