Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.

Commit 04295c1

Browse files
committed
relative path added to zip,
progress bar tweak (it works!)
1 parent dbcc3db commit 04295c1

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

paperspace/commands/experiments.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import pydoc
33
import zipfile
4+
from collections import OrderedDict
45

56
import click
67
import progressbar
@@ -10,7 +11,8 @@
1011

1112
from paperspace import logger, constants, client, config
1213
from paperspace.commands import CommandBase
13-
from paperspace.exceptions import PresignedUrlUnreachableException, S3UploadFailedException
14+
from paperspace.exceptions import PresignedUrlUnreachableException, S3UploadFailedException, \
15+
PresignedUrlAccessDeniedException
1416
from paperspace.logger import log_response
1517
from paperspace.utils import get_terminal_lines
1618

@@ -35,21 +37,25 @@ def _log_create_experiment(self, response, success_msg_template, error_msg):
3537

3638

3739
class CreateExperimentCommand(ExperimentCommand):
38-
def retrieve_file_paths(self, dirName):
40+
def _retrieve_file_paths(self, dirName):
3941

4042
# setup file paths variable
41-
filePaths = []
42-
exclude = ['.git']
43+
file_paths = {}
44+
exclude = ['.git', '.idea', '.pytest_cache']
4345
# Read all directory, subdirectories and file lists
4446
for root, dirs, files in os.walk(dirName, topdown=True):
4547
dirs[:] = [d for d in dirs if d not in exclude]
4648
for filename in files:
4749
# Create the full filepath by using os module.
48-
filePath = os.path.join(root, filename)
49-
filePaths.append(filePath)
50+
relpath = os.path.relpath(root, dirName)
51+
if relpath == '.':
52+
file_path = filename
53+
else:
54+
file_path = os.path.join(os.path.relpath(root, dirName), filename)
55+
file_paths[file_path] = os.path.join(root, filename)
5056

5157
# return all paths
52-
return filePaths
58+
return file_paths
5359

5460
def _zip_workspace(self, workspace_path):
5561
if not workspace_path:
@@ -64,7 +70,7 @@ def _zip_workspace(self, workspace_path):
6470
self.logger.log('Removing existing archive')
6571
os.remove(zip_file_path)
6672

67-
file_paths = self.retrieve_file_paths(workspace_path)
73+
file_paths = self._retrieve_file_paths(workspace_path)
6874

6975
self.logger.log('Creating zip archive: %s' % zip_file_name)
7076
zip_file = zipfile.ZipFile(zip_file_path, 'w')
@@ -73,10 +79,10 @@ def _zip_workspace(self, workspace_path):
7379

7480
with zip_file:
7581
i = 0
76-
for file in file_paths:
82+
for relpath, abspath in file_paths.items():
7783
i+=1
78-
self.logger.debug('Adding %s to archive' % file)
79-
zip_file.write(file)
84+
self.logger.debug('Adding %s to archive' % relpath)
85+
zip_file.write(abspath, arcname=relpath)
8086
bar.update(i)
8187
bar.finish()
8288
self.logger.log('\nFinished creating archive: %s' % zip_file_name)
@@ -87,8 +93,6 @@ def _create_callback(self, encoder_obj):
8793

8894
def callback(monitor):
8995
bar.update(monitor.bytes_read)
90-
if monitor.bytes_read == monitor.len:
91-
bar.finish()
9296
return callback
9397

9498
def _upload_workspace(self, input_data):
@@ -110,23 +114,24 @@ def _upload_workspace(self, input_data):
110114
else:
111115
archive_path = self._zip_workspace(workspace_path)
112116

113-
s3_upload_data = self._get_upload_data(os.path.basename(archive_path))
117+
file_name = os.path.basename(archive_path)
118+
s3_upload_data = self._get_upload_data(file_name)
119+
bucket_name = s3_upload_data['bucket_name']
114120

115121
self.logger.log('Uploading zipped workspace to S3')
116122

117123
files = {'file': (archive_path, open(archive_path, 'rb'))}
118-
fields = s3_upload_data['fields']
124+
fields = OrderedDict(s3_upload_data['fields'])
119125
fields.update(files)
120-
121126
s3_encoder = encoder.MultipartEncoder(fields=fields)
122127
monitor = encoder.MultipartEncoderMonitor(s3_encoder, callback=self._create_callback(s3_encoder))
123128
s3_response = requests.post(s3_upload_data['url'], data=monitor, headers={'Content-Type': monitor.content_type})
124129
if not s3_response.ok:
125130
raise S3UploadFailedException(s3_response)
126131

127132
self.logger.log('\nUploading completed')
128-
s3_workspace_url = s3_response.headers.get('Location')
129-
return s3_workspace_url
133+
134+
return 's3://{}/{}'.format(bucket_name, file_name)
130135

131136
def execute(self, json_):
132137
workspace_url = self._upload_workspace(json_)
@@ -143,6 +148,8 @@ def _get_upload_data(self, file_name):
143148
response = self.api.get("/workspace/get_presigned_url", params={'workspaceName': file_name})
144149
if response.status_code == 404:
145150
raise PresignedUrlUnreachableException
151+
if response.status_code == 403:
152+
raise PresignedUrlAccessDeniedException
146153
return response.json()
147154

148155

paperspace/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,9 @@ class PresignedUrlUnreachableException(ApplicationException):
1010
pass
1111

1212

13+
class PresignedUrlAccessDeniedException(ApplicationException):
14+
pass
15+
16+
1317
class S3UploadFailedException(ApplicationException):
1418
pass

0 commit comments

Comments
 (0)