Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.

Commit 17dea50

Browse files
committed
refactored - workspace handling extracted to module
1 parent 75644bc commit 17dea50

File tree

3 files changed

+139
-115
lines changed

3 files changed

+139
-115
lines changed

paperspace/commands/experiments.py

Lines changed: 10 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,21 @@
1-
import os
21
import pydoc
3-
import zipfile
4-
from collections import OrderedDict
52

6-
import click
7-
import progressbar
8-
import requests
93
import terminaltables
10-
from requests_toolbelt.multipart import encoder
114

125
from paperspace import logger, constants, client, config
136
from paperspace.commands import CommandBase
14-
from paperspace.exceptions import PresignedUrlUnreachableException, S3UploadFailedException, \
15-
PresignedUrlAccessDeniedException
7+
from paperspace.workspace import S3WorkspaceHandler
168
from paperspace.logger import log_response
179
from paperspace.utils import get_terminal_lines
1810

19-
# from clint.textui.progress import Bar as ProgressBar
20-
2111
experiments_api = client.API(config.CONFIG_EXPERIMENTS_HOST, headers=client.default_headers)
2212

2313

2414
class ExperimentCommand(CommandBase):
15+
def __init__(self, workspace_handler=None, **kwargs):
16+
super(ExperimentCommand, self).__init__(**kwargs)
17+
self._workspace_handler = workspace_handler or S3WorkspaceHandler(api=self.api, logger=self.logger)
18+
2519
def _log_create_experiment(self, response, success_msg_template, error_msg):
2620
if response.ok:
2721
j = response.json()
@@ -37,104 +31,9 @@ def _log_create_experiment(self, response, success_msg_template, error_msg):
3731

3832

3933
class CreateExperimentCommand(ExperimentCommand):
40-
def _retrieve_file_paths(self, dirName):
41-
42-
# setup file paths variable
43-
file_paths = {}
44-
exclude = ['.git', '.idea', '.pytest_cache']
45-
# Read all directory, subdirectories and file lists
46-
for root, dirs, files in os.walk(dirName, topdown=True):
47-
dirs[:] = [d for d in dirs if d not in exclude]
48-
for filename in files:
49-
# Create the full filepath by using os module.
50-
relpath = os.path.relpath(root, dirName)
51-
if relpath == '.':
52-
file_path = filename
53-
else:
54-
file_path = os.path.join(os.path.relpath(root, dirName), filename)
55-
file_paths[file_path] = os.path.join(root, filename)
56-
57-
# return all paths
58-
return file_paths
59-
60-
def _zip_workspace(self, workspace_path):
61-
if not workspace_path:
62-
workspace_path = '.'
63-
zip_file_name = os.path.basename(os.getcwd()) + '.zip'
64-
else:
65-
zip_file_name = os.path.basename(workspace_path) + '.zip'
66-
67-
zip_file_path = os.path.join(workspace_path, zip_file_name)
68-
69-
if os.path.exists(zip_file_path):
70-
self.logger.log('Removing existing archive')
71-
os.remove(zip_file_path)
72-
73-
file_paths = self._retrieve_file_paths(workspace_path)
74-
75-
self.logger.log('Creating zip archive: %s' % zip_file_name)
76-
zip_file = zipfile.ZipFile(zip_file_path, 'w')
77-
78-
bar = progressbar.ProgressBar(max_value=len(file_paths))
79-
80-
with zip_file:
81-
i = 0
82-
for relpath, abspath in file_paths.items():
83-
i+=1
84-
self.logger.debug('Adding %s to archive' % relpath)
85-
zip_file.write(abspath, arcname=relpath)
86-
bar.update(i)
87-
bar.finish()
88-
self.logger.log('\nFinished creating archive: %s' % zip_file_name)
89-
return zip_file_path
90-
91-
def _create_callback(self, encoder_obj):
92-
bar = progressbar.ProgressBar(max_value=encoder_obj.len)
93-
94-
def callback(monitor):
95-
bar.update(monitor.bytes_read)
96-
return callback
97-
98-
def _upload_workspace(self, input_data):
99-
workspace_url = input_data.get('workspaceUrl')
100-
workspace_path = input_data.get('workspacePath')
101-
workspace_archive = input_data.get('workspaceArchive')
102-
if (workspace_archive and workspace_path) or (workspace_archive and workspace_url) or (
103-
workspace_path and workspace_url):
104-
raise click.UsageError("Use either:\n\t--workspaceUrl to point repository URL"
105-
"\n\t--workspacePath to point on project directory"
106-
"\n\t--workspaceArchive to point on project ZIP archive"
107-
"\n or neither to use current directory")
108-
109-
if workspace_url:
110-
return # nothing to do
111-
112-
if workspace_archive:
113-
archive_path = os.path.abspath(workspace_archive)
114-
else:
115-
archive_path = self._zip_workspace(workspace_path)
116-
117-
file_name = os.path.basename(archive_path)
118-
s3_upload_data = self._get_upload_data(file_name)
119-
bucket_name = s3_upload_data['bucket_name']
120-
121-
self.logger.log('Uploading zipped workspace to S3')
122-
123-
files = {'file': (archive_path, open(archive_path, 'rb'))}
124-
fields = OrderedDict(s3_upload_data['fields'])
125-
fields.update(files)
126-
s3_encoder = encoder.MultipartEncoder(fields=fields)
127-
monitor = encoder.MultipartEncoderMonitor(s3_encoder, callback=self._create_callback(s3_encoder))
128-
s3_response = requests.post(s3_upload_data['url'], data=monitor, headers={'Content-Type': monitor.content_type})
129-
if not s3_response.ok:
130-
raise S3UploadFailedException(s3_response)
131-
132-
self.logger.log('\nUploading completed')
133-
134-
return 's3://{}/{}'.format(bucket_name, file_name)
13534

13635
def execute(self, json_):
137-
workspace_url = self._upload_workspace(json_)
36+
workspace_url = self._workspace_handler.upload_workspace(json_)
13837
if workspace_url:
13938
json_['workspaceUrl'] = workspace_url
14039

@@ -144,17 +43,13 @@ def execute(self, json_):
14443
"New experiment created with handle: {}",
14544
"Unknown error while creating the experiment")
14645

147-
def _get_upload_data(self, file_name):
148-
response = self.api.get("/workspace/get_presigned_url", params={'workspaceName': file_name})
149-
if response.status_code == 404:
150-
raise PresignedUrlUnreachableException
151-
if response.status_code == 403:
152-
raise PresignedUrlAccessDeniedException
153-
return response.json()
154-
15546

15647
class CreateAndStartExperimentCommand(ExperimentCommand):
15748
def execute(self, json_):
49+
workspace_url = self._workspace_handler.upload_workspace(json_)
50+
if workspace_url:
51+
json_['workspaceUrl'] = workspace_url
52+
15853
response = self.api.post("/experiments/create_and_start/", json=json_)
15954
self._log_create_experiment(response,
16055
"New experiment created and started with handle: {}",

paperspace/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,9 @@ class PresignedUrlAccessDeniedException(ApplicationException):
1414
pass
1515

1616

17+
class PresignedUrlConnectionException(ApplicationException):
18+
pass
19+
20+
1721
class S3UploadFailedException(ApplicationException):
1822
pass

paperspace/workspace.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import logging
2+
import os
3+
import zipfile
4+
from collections import OrderedDict
5+
6+
import click
7+
import progressbar
8+
import requests
9+
from requests_toolbelt.multipart import encoder
10+
11+
from paperspace.exceptions import S3UploadFailedException, PresignedUrlUnreachableException, \
12+
PresignedUrlAccessDeniedException, PresignedUrlConnectionException
13+
14+
15+
class S3WorkspaceHandler:
16+
def __init__(self, api, logger=None):
17+
self.api = api
18+
self.logger = logger or logging.getLogger()
19+
20+
def _retrieve_file_paths(self, dirName):
21+
22+
# setup file paths variable
23+
file_paths = {}
24+
exclude = ['.git', '.idea', '.pytest_cache']
25+
# Read all directory, subdirectories and file lists
26+
for root, dirs, files in os.walk(dirName, topdown=True):
27+
dirs[:] = [d for d in dirs if d not in exclude]
28+
for filename in files:
29+
# Create the full filepath by using os module.
30+
relpath = os.path.relpath(root, dirName)
31+
if relpath == '.':
32+
file_path = filename
33+
else:
34+
file_path = os.path.join(os.path.relpath(root, dirName), filename)
35+
file_paths[file_path] = os.path.join(root, filename)
36+
37+
# return all paths
38+
return file_paths
39+
40+
def _zip_workspace(self, workspace_path):
41+
if not workspace_path:
42+
workspace_path = '.'
43+
zip_file_name = os.path.basename(os.getcwd()) + '.zip'
44+
else:
45+
zip_file_name = os.path.basename(workspace_path) + '.zip'
46+
47+
zip_file_path = os.path.join(workspace_path, zip_file_name)
48+
49+
if os.path.exists(zip_file_path):
50+
self.logger.log('Removing existing archive')
51+
os.remove(zip_file_path)
52+
53+
file_paths = self._retrieve_file_paths(workspace_path)
54+
55+
self.logger.log('Creating zip archive: %s' % zip_file_name)
56+
zip_file = zipfile.ZipFile(zip_file_path, 'w')
57+
58+
bar = progressbar.ProgressBar(max_value=len(file_paths))
59+
60+
with zip_file:
61+
i = 0
62+
for relpath, abspath in file_paths.items():
63+
i += 1
64+
self.logger.debug('Adding %s to archive' % relpath)
65+
zip_file.write(abspath, arcname=relpath)
66+
bar.update(i)
67+
bar.finish()
68+
self.logger.log('\nFinished creating archive: %s' % zip_file_name)
69+
return zip_file_path
70+
71+
def _create_callback(self, encoder_obj):
72+
bar = progressbar.ProgressBar(max_value=encoder_obj.len)
73+
74+
def callback(monitor):
75+
bar.update(monitor.bytes_read)
76+
77+
return callback
78+
79+
def upload_workspace(self, input_data):
80+
workspace_url = input_data.get('workspaceUrl')
81+
workspace_path = input_data.get('workspacePath')
82+
workspace_archive = input_data.get('workspaceArchive')
83+
if (workspace_archive and workspace_path) or (workspace_archive and workspace_url) or (
84+
workspace_path and workspace_url):
85+
raise click.UsageError("Use either:\n\t--workspaceUrl to point repository URL"
86+
"\n\t--workspacePath to point on project directory"
87+
"\n\t--workspaceArchive to point on project ZIP archive"
88+
"\n or neither to use current directory")
89+
90+
if workspace_url:
91+
return # nothing to do
92+
93+
if workspace_archive:
94+
archive_path = os.path.abspath(workspace_archive)
95+
else:
96+
archive_path = self._zip_workspace(workspace_path)
97+
98+
file_name = os.path.basename(archive_path)
99+
s3_upload_data = self._get_upload_data(file_name)
100+
bucket_name = s3_upload_data['bucket_name']
101+
102+
self.logger.log('Uploading zipped workspace to S3')
103+
104+
files = {'file': (archive_path, open(archive_path, 'rb'))}
105+
fields = OrderedDict(s3_upload_data['fields'])
106+
fields.update(files)
107+
s3_encoder = encoder.MultipartEncoder(fields=fields)
108+
monitor = encoder.MultipartEncoderMonitor(s3_encoder, callback=self._create_callback(s3_encoder))
109+
s3_response = requests.post(s3_upload_data['url'], data=monitor, headers={'Content-Type': monitor.content_type})
110+
if not s3_response.ok:
111+
raise S3UploadFailedException(s3_response)
112+
113+
self.logger.log('\nUploading completed')
114+
115+
return 's3://{}/{}'.format(bucket_name, file_name)
116+
117+
def _get_upload_data(self, file_name):
118+
response = self.api.get("/workspace/get_presigned_url", params={'workspaceName': file_name})
119+
if response.status_code == 404:
120+
raise PresignedUrlUnreachableException
121+
if response.status_code == 403:
122+
raise PresignedUrlAccessDeniedException
123+
if not response.ok:
124+
raise PresignedUrlConnectionException(response.reason)
125+
return response.json()

0 commit comments

Comments
 (0)