11import os
22import pydoc
33import zipfile
4+ from collections import OrderedDict
45
56import click
67import progressbar
1011
1112from paperspace import logger , constants , client , config
1213from paperspace .commands import CommandBase
13- from paperspace .exceptions import PresignedUrlUnreachableException , S3UploadFailedException
14+ from paperspace .exceptions import PresignedUrlUnreachableException , S3UploadFailedException , \
15+ PresignedUrlAccessDeniedException
1416from paperspace .logger import log_response
1517from paperspace .utils import get_terminal_lines
1618
@@ -35,21 +37,25 @@ def _log_create_experiment(self, response, success_msg_template, error_msg):
3537
3638
3739class CreateExperimentCommand (ExperimentCommand ):
38- def retrieve_file_paths (self , dirName ):
40+ def _retrieve_file_paths (self , dirName ):
3941
4042 # setup file paths variable
41- filePaths = []
42- exclude = ['.git' ]
43+ file_paths = {}
44+ exclude = ['.git' , '.idea' , '.pytest_cache' ]
4345 # Read all directory, subdirectories and file lists
4446 for root , dirs , files in os .walk (dirName , topdown = True ):
4547 dirs [:] = [d for d in dirs if d not in exclude ]
4648 for filename in files :
4749 # Create the full filepath by using os module.
48- filePath = os .path .join (root , filename )
49- filePaths .append (filePath )
50+ relpath = os .path .relpath (root , dirName )
51+ if relpath == '.' :
52+ file_path = filename
53+ else :
54+ file_path = os .path .join (os .path .relpath (root , dirName ), filename )
55+ file_paths [file_path ] = os .path .join (root , filename )
5056
5157 # return all paths
52- return filePaths
58+ return file_paths
5359
5460 def _zip_workspace (self , workspace_path ):
5561 if not workspace_path :
@@ -64,7 +70,7 @@ def _zip_workspace(self, workspace_path):
6470 self .logger .log ('Removing existing archive' )
6571 os .remove (zip_file_path )
6672
67- file_paths = self .retrieve_file_paths (workspace_path )
73+ file_paths = self ._retrieve_file_paths (workspace_path )
6874
6975 self .logger .log ('Creating zip archive: %s' % zip_file_name )
7076 zip_file = zipfile .ZipFile (zip_file_path , 'w' )
@@ -73,10 +79,10 @@ def _zip_workspace(self, workspace_path):
7379
7480 with zip_file :
7581 i = 0
76- for file in file_paths :
82+ for relpath , abspath in file_paths . items () :
7783 i += 1
78- self .logger .debug ('Adding %s to archive' % file )
79- zip_file .write (file )
84+ self .logger .debug ('Adding %s to archive' % relpath )
85+ zip_file .write (abspath , arcname = relpath )
8086 bar .update (i )
8187 bar .finish ()
8288 self .logger .log ('\n Finished creating archive: %s' % zip_file_name )
@@ -87,8 +93,6 @@ def _create_callback(self, encoder_obj):
8793
8894 def callback (monitor ):
8995 bar .update (monitor .bytes_read )
90- if monitor .bytes_read == monitor .len :
91- bar .finish ()
9296 return callback
9397
9498 def _upload_workspace (self , input_data ):
@@ -110,23 +114,24 @@ def _upload_workspace(self, input_data):
110114 else :
111115 archive_path = self ._zip_workspace (workspace_path )
112116
113- s3_upload_data = self ._get_upload_data (os .path .basename (archive_path ))
117+ file_name = os .path .basename (archive_path )
118+ s3_upload_data = self ._get_upload_data (file_name )
119+ bucket_name = s3_upload_data ['bucket_name' ]
114120
115121 self .logger .log ('Uploading zipped workspace to S3' )
116122
117123 files = {'file' : (archive_path , open (archive_path , 'rb' ))}
118- fields = s3_upload_data ['fields' ]
124+ fields = OrderedDict ( s3_upload_data ['fields' ])
119125 fields .update (files )
120-
121126 s3_encoder = encoder .MultipartEncoder (fields = fields )
122127 monitor = encoder .MultipartEncoderMonitor (s3_encoder , callback = self ._create_callback (s3_encoder ))
123128 s3_response = requests .post (s3_upload_data ['url' ], data = monitor , headers = {'Content-Type' : monitor .content_type })
124129 if not s3_response .ok :
125130 raise S3UploadFailedException (s3_response )
126131
127132 self .logger .log ('\n Uploading completed' )
128- s3_workspace_url = s3_response . headers . get ( 'Location' )
129- return s3_workspace_url
133+
134+ return 's3://{}/{}' . format ( bucket_name , file_name )
130135
131136 def execute (self , json_ ):
132137 workspace_url = self ._upload_workspace (json_ )
@@ -143,6 +148,8 @@ def _get_upload_data(self, file_name):
143148 response = self .api .get ("/workspace/get_presigned_url" , params = {'workspaceName' : file_name })
144149 if response .status_code == 404 :
145150 raise PresignedUrlUnreachableException
151+ if response .status_code == 403 :
152+ raise PresignedUrlAccessDeniedException
146153 return response .json ()
147154
148155
0 commit comments