|
1 | | -import logging |
2 | 1 | import os |
3 | 2 | import zipfile |
4 | 3 | from collections import OrderedDict |
|
7 | 6 | import progressbar |
8 | 7 | import requests |
9 | 8 | from requests_toolbelt.multipart import encoder |
| 9 | +from paperspace import logger as default_logger |
10 | 10 |
|
11 | 11 | from paperspace.exceptions import S3UploadFailedException, PresignedUrlUnreachableException, \ |
12 | 12 | PresignedUrlAccessDeniedException, PresignedUrlConnectionException |
13 | 13 |
|
14 | 14 |
|
15 | 15 | class S3WorkspaceHandler: |
16 | 16 | def __init__(self, experiments_api, logger=None): |
| 17 | + """ |
| 18 | +
|
| 19 | + :param experiments_api: paperspace.client.API |
| 20 | + :param logger: paperspace.logger |
| 21 | + """ |
17 | 22 | self.experiments_api = experiments_api |
18 | | - self.logger = logger or logging.getLogger() |
| 23 | + self.logger = logger or default_logger |
19 | 24 |
|
20 | 25 | def _retrieve_file_paths(self, dirName): |
21 | | - |
22 | 26 | # setup file paths variable |
23 | 27 | file_paths = {} |
24 | 28 | exclude = ['.git', '.idea', '.pytest_cache'] |
@@ -97,24 +101,32 @@ def upload_workspace(self, input_data): |
97 | 101 |
|
98 | 102 | file_name = os.path.basename(archive_path) |
99 | 103 | project_handle = input_data['projectHandle'] |
| 104 | + |
100 | 105 | s3_upload_data = self._get_upload_data(file_name, project_handle) |
| 106 | + |
101 | 107 | bucket_name = s3_upload_data['bucket_name'] |
| 108 | + s3_object_path = s3_upload_data['fields']['key'] |
102 | 109 |
|
103 | 110 | self.logger.log('Uploading zipped workspace to S3') |
104 | 111 |
|
| 112 | + self._upload(archive_path, s3_upload_data) |
| 113 | + |
| 114 | + self.logger.log('\nUploading completed') |
| 115 | + |
| 116 | + return 's3://{}/{}'.format(bucket_name, s3_object_path) |
| 117 | + |
| 118 | + def _upload(self, archive_path, s3_upload_data): |
105 | 119 | files = {'file': (archive_path, open(archive_path, 'rb'))} |
106 | 120 | fields = OrderedDict(s3_upload_data['fields']) |
107 | 121 | fields.update(files) |
| 122 | + |
108 | 123 | s3_encoder = encoder.MultipartEncoder(fields=fields) |
109 | 124 | monitor = encoder.MultipartEncoderMonitor(s3_encoder, callback=self._create_callback(s3_encoder)) |
110 | 125 | s3_response = requests.post(s3_upload_data['url'], data=monitor, headers={'Content-Type': monitor.content_type}) |
| 126 | + self.logger.debug("S3 upload response: {}".format(s3_response.headers)) |
111 | 127 | if not s3_response.ok: |
112 | 128 | raise S3UploadFailedException(s3_response) |
113 | 129 |
|
114 | | - self.logger.log('\nUploading completed') |
115 | | - |
116 | | - return 's3://{}/{}'.format(bucket_name, file_name) |
117 | | - |
118 | 130 | def _get_upload_data(self, file_name, project_handle): |
119 | 131 | response = self.experiments_api.get("/workspace/get_presigned_url", |
120 | 132 | params={'workspaceName': file_name, 'projectHandle': project_handle}) |
|
0 commit comments