Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.

Commit 667452e

Browse files
committed
Add 'models list' command with filtering by experimentId
1 parent 35c6b83 commit 667452e

File tree

7 files changed

+1807
-961
lines changed

7 files changed

+1807
-961
lines changed

paperspace/cli/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from paperspace import constants, client, config
77
from paperspace.cli.common import api_key_option, del_if_value_is_none
88
from paperspace.cli.jobs import jobs_group
9+
from paperspace.cli.models import models_group
910
from paperspace.cli.projects import projects_group
1011
from paperspace.cli.types import ChoiceType, json_string
1112
from paperspace.cli.validators import validate_mutually_exclusive, validate_email
@@ -1054,3 +1055,4 @@ def version():
10541055

10551056
cli.add_command(jobs_group)
10561057
cli.add_command(projects_group)
1058+
cli.add_command(models_group)

paperspace/cli/models.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import click
2+
3+
from paperspace import client, config
4+
from paperspace.cli import common
5+
from paperspace.commands import models as models_commands
6+
7+
8+
@click.group("models", help="Manage models")
9+
def models_group():
10+
pass
11+
12+
13+
@models_group.command("list", help="List models with optional filtering")
14+
@click.option(
15+
"--experimentId",
16+
"experimentId",
17+
help="Use to filter jobs by experiment ID",
18+
)
19+
@common.api_key_option
20+
def list_jobs(api_key, **filters):
21+
common.del_if_value_is_none(filters)
22+
jobs_api = client.API(config.CONFIG_HOST, api_key=api_key)
23+
command = models_commands.ListModelsCommand(api=jobs_api)
24+
command.execute(filters)

paperspace/commands/models.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import pydoc
2+
3+
import terminaltables
4+
5+
from paperspace.utils import get_terminal_lines
6+
7+
from paperspace.commands import CommandBase
8+
9+
10+
class ListModelsCommand(CommandBase):
11+
def execute(self, filters):
12+
json_ = self._get_request_json(filters)
13+
params = {"limit": -1} # so the api returns full list without pagination
14+
response = self.api.get("/mlModels/getModelList/", json=json_, params=params)
15+
16+
try:
17+
models = self._get_models_list(response)
18+
except (ValueError, KeyError) as e:
19+
self.logger.log("Error while parsing response data: {}".format(e))
20+
else:
21+
self._log_models_list(models)
22+
23+
@staticmethod
24+
def _get_request_json(filters):
25+
experiment_id = filters.get("experimentId")
26+
if not experiment_id:
27+
return None
28+
29+
json_ = {"filter": {"where": {"and": [{"experimentId": experiment_id}]}}}
30+
return json_
31+
32+
def _get_models_list(self, response):
33+
if not response.ok:
34+
raise ValueError("Unknown error")
35+
36+
data = response.json()["modelList"]
37+
self.logger.debug(data)
38+
return data
39+
40+
def _log_models_list(self, model):
41+
if not model:
42+
self.logger.log("No models found")
43+
else:
44+
table_str = self._make_models_list_table(model)
45+
if len(table_str.splitlines()) > get_terminal_lines():
46+
pydoc.pager(table_str)
47+
else:
48+
self.logger.log(table_str)
49+
50+
@staticmethod
51+
def _make_models_list_table(models):
52+
data = [("Name", "ID", "Model Type", "Project ID", "Experiment ID")]
53+
for model in models:
54+
name = model.get("name")
55+
id_ = model.get("id")
56+
model_type = model.get("modelType")
57+
project_id = model.get("projectId")
58+
experiment_id = model.get("experimentId")
59+
data.append((name, id_, model_type, project_id, experiment_id))
60+
61+
ascii_table = terminaltables.AsciiTable(data)
62+
table_string = ascii_table.table
63+
return table_string
64+

paperspace/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
def main():
1111
if len(sys.argv) >= 2 and sys.argv[1] in ('experiments', 'deployments', 'machines', 'login', 'logout', 'version',
12-
'projects', 'jobs'):
12+
'projects', 'jobs', 'models'):
1313
cli(sys.argv[1:])
1414

1515
args = sys.argv[:]

0 commit comments

Comments
 (0)