Skip to content

Commit 1863c4a

Browse files
committed
Fetch tenant details
for available regions and if tenant has ds capabilities
1 parent aa405c1 commit 1863c4a

File tree

2 files changed

+82
-7
lines changed

2 files changed

+82
-7
lines changed

graphdatascience/gds_session/aura_api.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import logging
55
import os
66
import time
7+
from collections import defaultdict
78
from dataclasses import dataclass
8-
from typing import Any, List, Optional
9+
from typing import Any, List, Optional, Set
910
from urllib.parse import urlparse
1011

1112
import requests as req
@@ -14,7 +15,7 @@
1415
from graphdatascience.version import __version__
1516

1617

17-
@dataclass(repr=True)
18+
@dataclass(repr=True, frozen=True)
1819
class InstanceDetails:
1920
id: str
2021
name: str
@@ -31,7 +32,7 @@ def fromJson(cls, json: dict[str, Any]) -> InstanceDetails:
3132
)
3233

3334

34-
@dataclass(repr=True)
35+
@dataclass(repr=True, frozen=True)
3536
class InstanceSpecificDetails(InstanceDetails):
3637
status: str
3738
connection_url: str
@@ -54,7 +55,7 @@ def fromJson(cls, json: dict[str, Any]) -> InstanceSpecificDetails:
5455
)
5556

5657

57-
@dataclass(repr=True)
58+
@dataclass(repr=True, frozen=True)
5859
class InstanceCreateDetails:
5960
id: str
6061
username: str
@@ -70,6 +71,35 @@ def from_json(cls, json: dict[str, Any]) -> InstanceCreateDetails:
7071
return cls(**{f.name: json[f.name] for f in fields})
7172

7273

74+
@dataclass(repr=True, frozen=True)
75+
class TenantDetails:
76+
id: str
77+
ds_type: str
78+
regions_per_provider: dict[str, Set[str]]
79+
80+
@classmethod
81+
def from_json(cls, json: dict[str, Any]) -> TenantDetails:
82+
regions_per_provider = defaultdict(set)
83+
instance_types = set()
84+
ds_type = None
85+
86+
for configs in json["instance_configurations"]:
87+
type = configs["type"]
88+
if type.split("-")[1] == "ds":
89+
regions_per_provider[configs["cloud_provider"]].add(configs["region"])
90+
ds_type = type
91+
instance_types.add(configs["type"])
92+
93+
if not ds_type:
94+
raise RuntimeError(f"Tenant cannot create DS instances. Available instances are `{instance_types}`.")
95+
96+
return cls(
97+
id=json["id"],
98+
ds_type=ds_type,
99+
regions_per_provider=regions_per_provider,
100+
)
101+
102+
73103
class AuraApi:
74104
class AuraAuthToken:
75105
access_token: str
@@ -99,6 +129,7 @@ def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str]
99129
self._token: Optional[AuraApi.AuraAuthToken] = None
100130
self._logger = logging.getLogger()
101131
self._tenant_id = tenant_id if tenant_id else self._get_tenant_id()
132+
self._tenant_details: Optional[TenantDetails] = None
102133

103134
@staticmethod
104135
def extract_id(uri: str) -> str:
@@ -110,14 +141,14 @@ def extract_id(uri: str) -> str:
110141
return host.split(".")[0].split("-")[0]
111142

112143
def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails:
113-
# TODO should give more control here
144+
tenant_details = self.tenant_details()
145+
114146
data = {
115147
"name": name,
116148
"memory": memory,
117149
"version": "5",
118150
"region": region,
119-
# TODO should be figured out from the tenant details in the future
120-
"type": self._instance_type(),
151+
"type": tenant_details.ds_type,
121152
"tenant_id": self._tenant_id,
122153
"cloud_provider": cloud_provider,
123154
}
@@ -216,6 +247,16 @@ def _get_tenant_id(self) -> str:
216247

217248
return raw_data[0]["id"] # type: ignore
218249

250+
def tenant_details(self) -> TenantDetails:
251+
if not self._tenant_details:
252+
response = req.get(
253+
f"{self._base_uri}/v1/tenants/{self._tenant_id}",
254+
headers=self._build_header(),
255+
)
256+
response.raise_for_status()
257+
self._tenant_details = TenantDetails.from_json(response.json()["data"])
258+
return self._tenant_details
259+
219260
def _build_header(self) -> dict[str, str]:
220261
return {"Authorization": f"Bearer {self._auth_token()}", "User-agent": f"neo4j-graphdatascience-v{__version__}"}
221262

graphdatascience/tests/unit/test_aura_api.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AuraApi,
1010
InstanceCreateDetails,
1111
InstanceSpecificDetails,
12+
TenantDetails,
1213
)
1314

1415

@@ -346,3 +347,36 @@ def test_parse_create_details() -> None:
346347
InstanceCreateDetails.from_json(
347348
{"id": "1", "username": "mats", "password": "1234", "connection_url": "url", "region": "fooo"}
348349
)
350+
351+
352+
def test_parse_tenant_details() -> None:
353+
details = TenantDetails.from_json(
354+
{
355+
"id": "42",
356+
"instance_configurations": [
357+
{"type": "enterprise-db", "region": "eu-west1", "cloud_provider": "aws"},
358+
{"type": "enterprise-ds", "region": "eu-west3", "cloud_provider": "gcp"},
359+
{"type": "enterprise-ds", "region": "us-central1", "cloud_provider": "aws"},
360+
{"type": "enterprise-ds", "region": "us-central3", "cloud_provider": "aws"},
361+
],
362+
}
363+
)
364+
365+
expected_details = TenantDetails(
366+
"42", ds_type="enterprise-ds", regions_per_provider={"gcp": {"eu-west3"}, "aws": {"us-central1", "us-central3"}}
367+
)
368+
assert details == expected_details
369+
370+
371+
def test_parse_non_ds_details() -> None:
372+
with pytest.raises(
373+
RuntimeError, match="Tenant cannot create DS instances. Available instances are `{'enterprise-db'}`."
374+
):
375+
TenantDetails.from_json(
376+
{
377+
"id": "42",
378+
"instance_configurations": [
379+
{"type": "enterprise-db", "region": "europe-west1", "cloud_provider": "aws"}
380+
],
381+
}
382+
)

0 commit comments

Comments
 (0)