1- import os
21import pydoc
3- import zipfile
4- from collections import OrderedDict
52
6- import click
7- import progressbar
8- import requests
93import terminaltables
10- from requests_toolbelt .multipart import encoder
114
125from paperspace import logger , constants , client , config
136from paperspace .commands import CommandBase
14- from paperspace .exceptions import PresignedUrlUnreachableException , S3UploadFailedException , \
15- PresignedUrlAccessDeniedException
7+ from paperspace .workspace import S3WorkspaceHandler
168from paperspace .logger import log_response
179from paperspace .utils import get_terminal_lines
1810
19- # from clint.textui.progress import Bar as ProgressBar
20-
2111experiments_api = client .API (config .CONFIG_EXPERIMENTS_HOST , headers = client .default_headers )
2212
2313
2414class ExperimentCommand (CommandBase ):
15+ def __init__ (self , workspace_handler = None , ** kwargs ):
16+ super (ExperimentCommand , self ).__init__ (** kwargs )
17+ self ._workspace_handler = workspace_handler or S3WorkspaceHandler (api = self .api , logger = self .logger )
18+
2519 def _log_create_experiment (self , response , success_msg_template , error_msg ):
2620 if response .ok :
2721 j = response .json ()
@@ -37,104 +31,9 @@ def _log_create_experiment(self, response, success_msg_template, error_msg):
3731
3832
3933class CreateExperimentCommand (ExperimentCommand ):
40- def _retrieve_file_paths (self , dirName ):
41-
42- # setup file paths variable
43- file_paths = {}
44- exclude = ['.git' , '.idea' , '.pytest_cache' ]
45- # Read all directory, subdirectories and file lists
46- for root , dirs , files in os .walk (dirName , topdown = True ):
47- dirs [:] = [d for d in dirs if d not in exclude ]
48- for filename in files :
49- # Create the full filepath by using os module.
50- relpath = os .path .relpath (root , dirName )
51- if relpath == '.' :
52- file_path = filename
53- else :
54- file_path = os .path .join (os .path .relpath (root , dirName ), filename )
55- file_paths [file_path ] = os .path .join (root , filename )
56-
57- # return all paths
58- return file_paths
59-
60- def _zip_workspace (self , workspace_path ):
61- if not workspace_path :
62- workspace_path = '.'
63- zip_file_name = os .path .basename (os .getcwd ()) + '.zip'
64- else :
65- zip_file_name = os .path .basename (workspace_path ) + '.zip'
66-
67- zip_file_path = os .path .join (workspace_path , zip_file_name )
68-
69- if os .path .exists (zip_file_path ):
70- self .logger .log ('Removing existing archive' )
71- os .remove (zip_file_path )
72-
73- file_paths = self ._retrieve_file_paths (workspace_path )
74-
75- self .logger .log ('Creating zip archive: %s' % zip_file_name )
76- zip_file = zipfile .ZipFile (zip_file_path , 'w' )
77-
78- bar = progressbar .ProgressBar (max_value = len (file_paths ))
79-
80- with zip_file :
81- i = 0
82- for relpath , abspath in file_paths .items ():
83- i += 1
84- self .logger .debug ('Adding %s to archive' % relpath )
85- zip_file .write (abspath , arcname = relpath )
86- bar .update (i )
87- bar .finish ()
88- self .logger .log ('\n Finished creating archive: %s' % zip_file_name )
89- return zip_file_path
90-
91- def _create_callback (self , encoder_obj ):
92- bar = progressbar .ProgressBar (max_value = encoder_obj .len )
93-
94- def callback (monitor ):
95- bar .update (monitor .bytes_read )
96- return callback
97-
98- def _upload_workspace (self , input_data ):
99- workspace_url = input_data .get ('workspaceUrl' )
100- workspace_path = input_data .get ('workspacePath' )
101- workspace_archive = input_data .get ('workspaceArchive' )
102- if (workspace_archive and workspace_path ) or (workspace_archive and workspace_url ) or (
103- workspace_path and workspace_url ):
104- raise click .UsageError ("Use either:\n \t --workspaceUrl to point repository URL"
105- "\n \t --workspacePath to point on project directory"
106- "\n \t --workspaceArchive to point on project ZIP archive"
107- "\n or neither to use current directory" )
108-
109- if workspace_url :
110- return # nothing to do
111-
112- if workspace_archive :
113- archive_path = os .path .abspath (workspace_archive )
114- else :
115- archive_path = self ._zip_workspace (workspace_path )
116-
117- file_name = os .path .basename (archive_path )
118- s3_upload_data = self ._get_upload_data (file_name )
119- bucket_name = s3_upload_data ['bucket_name' ]
120-
121- self .logger .log ('Uploading zipped workspace to S3' )
122-
123- files = {'file' : (archive_path , open (archive_path , 'rb' ))}
124- fields = OrderedDict (s3_upload_data ['fields' ])
125- fields .update (files )
126- s3_encoder = encoder .MultipartEncoder (fields = fields )
127- monitor = encoder .MultipartEncoderMonitor (s3_encoder , callback = self ._create_callback (s3_encoder ))
128- s3_response = requests .post (s3_upload_data ['url' ], data = monitor , headers = {'Content-Type' : monitor .content_type })
129- if not s3_response .ok :
130- raise S3UploadFailedException (s3_response )
131-
132- self .logger .log ('\n Uploading completed' )
133-
134- return 's3://{}/{}' .format (bucket_name , file_name )
13534
13635 def execute (self , json_ ):
137- workspace_url = self ._upload_workspace (json_ )
36+ workspace_url = self ._workspace_handler . upload_workspace (json_ )
13837 if workspace_url :
13938 json_ ['workspaceUrl' ] = workspace_url
14039
@@ -144,17 +43,13 @@ def execute(self, json_):
14443 "New experiment created with handle: {}" ,
14544 "Unknown error while creating the experiment" )
14645
147- def _get_upload_data (self , file_name ):
148- response = self .api .get ("/workspace/get_presigned_url" , params = {'workspaceName' : file_name })
149- if response .status_code == 404 :
150- raise PresignedUrlUnreachableException
151- if response .status_code == 403 :
152- raise PresignedUrlAccessDeniedException
153- return response .json ()
154-
15546
15647class CreateAndStartExperimentCommand (ExperimentCommand ):
15748 def execute (self , json_ ):
49+ workspace_url = self ._workspace_handler .upload_workspace (json_ )
50+ if workspace_url :
51+ json_ ['workspaceUrl' ] = workspace_url
52+
15853 response = self .api .post ("/experiments/create_and_start/" , json = json_ )
15954 self ._log_create_experiment (response ,
16055 "New experiment created and started with handle: {}" ,
0 commit comments