44import logging
55import os
66import time
7+ from collections import defaultdict
78from dataclasses import dataclass
8- from typing import Any , List , Optional
9+ from typing import Any , List , Optional , Set
910from urllib .parse import urlparse
1011
1112import requests as req
1415from graphdatascience .version import __version__
1516
1617
17- @dataclass (repr = True )
18+ @dataclass (repr = True , frozen = True )
1819class InstanceDetails :
1920 id : str
2021 name : str
@@ -31,7 +32,7 @@ def fromJson(cls, json: dict[str, Any]) -> InstanceDetails:
3132 )
3233
3334
34- @dataclass (repr = True )
35+ @dataclass (repr = True , frozen = True )
3536class InstanceSpecificDetails (InstanceDetails ):
3637 status : str
3738 connection_url : str
@@ -54,7 +55,7 @@ def fromJson(cls, json: dict[str, Any]) -> InstanceSpecificDetails:
5455 )
5556
5657
57- @dataclass (repr = True )
58+ @dataclass (repr = True , frozen = True )
5859class InstanceCreateDetails :
5960 id : str
6061 username : str
@@ -70,6 +71,35 @@ def from_json(cls, json: dict[str, Any]) -> InstanceCreateDetails:
7071 return cls (** {f .name : json [f .name ] for f in fields })
7172
7273
74+ @dataclass (repr = True , frozen = True )
75+ class TenantDetails :
76+ id : str
77+ ds_type : str
78+ regions_per_provider : dict [str , Set [str ]]
79+
80+ @classmethod
81+ def from_json (cls , json : dict [str , Any ]) -> TenantDetails :
82+ regions_per_provider = defaultdict (set )
83+ instance_types = set ()
84+ ds_type = None
85+
86+ for configs in json ["instance_configurations" ]:
87+ type = configs ["type" ]
88+ if type .split ("-" )[1 ] == "ds" :
89+ regions_per_provider [configs ["cloud_provider" ]].add (configs ["region" ])
90+ ds_type = type
91+ instance_types .add (configs ["type" ])
92+
93+ if not ds_type :
94+ raise RuntimeError (f"Tenant cannot create DS instances. Available instances are `{ instance_types } `." )
95+
96+ return cls (
97+ id = json ["id" ],
98+ ds_type = ds_type ,
99+ regions_per_provider = regions_per_provider ,
100+ )
101+
102+
73103class AuraApi :
74104 class AuraAuthToken :
75105 access_token : str
@@ -99,6 +129,7 @@ def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str]
99129 self ._token : Optional [AuraApi .AuraAuthToken ] = None
100130 self ._logger = logging .getLogger ()
101131 self ._tenant_id = tenant_id if tenant_id else self ._get_tenant_id ()
132+ self ._tenant_details : Optional [TenantDetails ] = None
102133
103134 @staticmethod
104135 def extract_id (uri : str ) -> str :
@@ -110,14 +141,14 @@ def extract_id(uri: str) -> str:
110141 return host .split ("." )[0 ].split ("-" )[0 ]
111142
112143 def create_instance (self , name : str , memory : str , cloud_provider : str , region : str ) -> InstanceCreateDetails :
113- # TODO should give more control here
144+ tenant_details = self .tenant_details ()
145+
114146 data = {
115147 "name" : name ,
116148 "memory" : memory ,
117149 "version" : "5" ,
118150 "region" : region ,
119- # TODO should be figured out from the tenant details in the future
120- "type" : self ._instance_type (),
151+ "type" : tenant_details .ds_type ,
121152 "tenant_id" : self ._tenant_id ,
122153 "cloud_provider" : cloud_provider ,
123154 }
@@ -216,6 +247,16 @@ def _get_tenant_id(self) -> str:
216247
217248 return raw_data [0 ]["id" ] # type: ignore
218249
250+ def tenant_details (self ) -> TenantDetails :
251+ if not self ._tenant_details :
252+ response = req .get (
253+ f"{ self ._base_uri } /v1/tenants/{ self ._tenant_id } " ,
254+ headers = self ._build_header (),
255+ )
256+ response .raise_for_status ()
257+ self ._tenant_details = TenantDetails .from_json (response .json ()["data" ])
258+ return self ._tenant_details
259+
219260 def _build_header (self ) -> dict [str , str ]:
220261 return {"Authorization" : f"Bearer { self ._auth_token ()} " , "User-agent" : f"neo4j-graphdatascience-v{ __version__ } " }
221262
0 commit comments